import json
import os
import pickle
from collections import Counter
from logging import getLogger
from typing import Any, Dict, Iterator, List

import faiss

from ielts_be.services import IKnowledgeBase


class TrainingContentKnowledgeBase(IKnowledgeBase):
    """Knowledge base of training tips backed by one FAISS index per category.

    On construction the per-category FAISS indices and the pickled tip
    metadata are loaded from disk. The raw tips JSON is only read on demand
    (by ``print_category_count`` and ``create_embeddings_and_save_them``).
    """

    def __init__(self, embeddings, path: str = 'pathways_2_rw_with_ids.json'):
        """Store the embedding model and load the persisted indices.

        Args:
            embeddings: Model object exposing ``encode(list)``; its result is
                fed to FAISS and must expose ``shape`` (a 2-D array).
            path: Location of the tips JSON file. It is not read here; it is
                loaded lazily the first time the raw tips are needed.
        """
        self._embedding_model = embeddings
        self._tips_path = path
        # Loaded lazily by _get_tips() so that start-up does not depend on
        # the JSON file being present.
        self._tips = None
        self._category_metadata = None
        self._indices = None
        self._logger = getLogger(__name__)
        self.load_indices_and_metadata()

    @staticmethod
    def _read_json(path: str) -> Dict[str, Any]:
        """Parse the UTF-8 encoded JSON file at ``path`` and return it."""
        with open(path, 'r', encoding="utf-8") as json_file:
            return json.load(json_file)

    @staticmethod
    def _normalize_category(raw: str) -> str:
        """Return the lower-cased category name with spaces as underscores."""
        return raw.lower().replace(" ", "_")

    def _get_tips(self) -> Dict[str, Any]:
        """Return the tips document, reading it from disk on first use."""
        if self._tips is None:
            self._tips = self._read_json(self._tips_path)
        return self._tips

    def _iter_tips(self) -> Iterator[Dict[str, Any]]:
        """Yield every tip found under ``units -> pages -> tips``."""
        for unit in self._get_tips()['units']:
            for page in unit['pages']:
                yield from page['tips']

    def print_category_count(self):
        """Print a dict mapping each normalized category to its tip count."""
        # Counter counts the first occurrence as 1; the previous hand-rolled
        # tally started at 0 and under-reported every category by one.
        counts = Counter(
            self._normalize_category(tip['category'])
            for tip in self._iter_tips()
        )
        print(dict(counts))

    def create_embeddings_and_save_them(
            self,
            directory: str = './faiss',
            suffix: str = '_tips_index.faiss',
            metadata_path: str = './faiss/tips_metadata.pkl'
    ) -> None:
        """Build one FAISS L2 index per category and persist them to disk.

        Each index is written to ``<directory>/<category><suffix>`` and the
        per-category metadata (``id`` and ``text`` of every tip, in index
        order) is pickled to ``metadata_path``. The defaults match those of
        ``load_indices_and_metadata``. The indices held in memory by this
        instance are not replaced; call ``load_indices_and_metadata`` to
        pick up the new files.

        Args:
            directory: Folder receiving the ``.faiss`` files; created if
                missing.
            suffix: File-name suffix appended to the category name.
            metadata_path: Destination of the pickled metadata.
        """
        category_embeddings: Dict[str, list] = {}
        category_metadata: Dict[str, List[Dict[str, Any]]] = {}
        for tip in self._iter_tips():
            category = self._normalize_category(tip['category'])
            # NOTE(review): the value under 'embedding' is handed to the
            # model's encode() below, so it is presumably the text to embed
            # rather than a precomputed vector - confirm against the JSON.
            category_embeddings.setdefault(category, []).append(
                tip['embedding']
            )
            category_metadata.setdefault(category, []).append(
                {"id": tip['id'], "text": tip['text']}
            )

        # The output folders may not exist on a fresh checkout.
        os.makedirs(directory, exist_ok=True)
        os.makedirs(os.path.dirname(metadata_path) or '.', exist_ok=True)

        for category, embeddings in category_embeddings.items():
            embeddings_array = self._embedding_model.encode(embeddings)
            index = faiss.IndexFlatL2(embeddings_array.shape[1])
            index.add(embeddings_array)
            faiss.write_index(
                index, os.path.join(directory, f"{category}{suffix}")
            )

        with open(metadata_path, "wb") as f:
            pickle.dump(category_metadata, f)

    def load_indices_and_metadata(
            self,
            directory: str = './faiss',
            suffix: str = '_tips_index.faiss',
            metadata_path: str = './faiss/tips_metadata.pkl'
    ):
        """Load every FAISS index in ``directory`` plus the tip metadata.

        Files whose name ends with ``suffix`` are read as FAISS indices and
        stored under the file name with that suffix removed. The metadata
        pickle is then loaded from ``metadata_path``.

        Raises:
            FileNotFoundError: If ``directory`` or ``metadata_path`` does
                not exist.
        """
        indices = {}
        for file in os.listdir(directory):
            if not file.endswith(suffix):
                continue
            # Explicit end offset so that an empty suffix keeps the full
            # name (``file[:-0]`` would yield an empty string).
            category = file[:len(file) - len(suffix)]
            indices[category] = faiss.read_index(os.path.join(directory, file))
            self._logger.info('Loaded embeddings for %s category.', category)
        self._indices = indices

        # SECURITY: pickle.load executes arbitrary code embedded in the
        # file; metadata_path must only ever point at a trusted artefact.
        with open(metadata_path, 'rb') as f:
            self._category_metadata = pickle.load(f)
        self._logger.info("Loaded tips metadata")

    def query_knowledge_base(self, query: str, category: str, top_k: int = 5) -> List[Dict[str, str]]:
        """Return the metadata of the tips closest to ``query``.

        Args:
            query: Free text that is encoded with the embedding model.
            category: Key of the index to search, as loaded by
                ``load_indices_and_metadata``.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` metadata entries in the order FAISS returned
            them. Fewer entries are returned when the index holds fewer
            than ``top_k`` vectors.

        Raises:
            KeyError: If ``category`` has no index or no metadata.
        """
        index = self._indices[category]
        metadata = self._category_metadata[category]
        if top_k <= 0:
            return []

        query_embedding = self._embedding_model.encode([query])
        _, neighbours = index.search(query_embedding, top_k)
        # FAISS pads missing neighbours with -1, which as a Python index
        # would silently return the *last* tip; drop anything out of range.
        return [
            metadata[int(i)]
            for i in neighbours[0]
            if 0 <= i < len(metadata)
        ]