import asyncio
import importlib
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory

logger = logging.getLogger(__name__)


class GraphRAGAdapter:
    """Adapter that manages one lazily-created nano-graphrag instance per knowledge base.

    Instances are cached by ``kb_id``; creation is guarded by a class-level
    asyncio lock (double-checked locking) and ingestion into each knowledge
    base is serialized by a per-kb lock.
    """

    # Guards lazy creation of per-kb GraphRAG instances (double-checked locking).
    _instance_lock = asyncio.Lock()

    def __init__(self):
        # kb_id -> cached GraphRAG instance
        self._graphrag_instances: Dict[int, Any] = {}
        # kb_id -> lock serializing ingestion for that knowledge base
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        self._embedding_model = EmbeddingsFactory.create()
        self._llm_model = LLMFactory.create(streaming=False)
        self._symbols = self._load_symbols()

    def _load_symbols(self) -> Dict[str, Any]:
        """Import nano_graphrag lazily and collect the symbols this adapter uses."""
        module = importlib.import_module("nano_graphrag")
        storage_module = importlib.import_module("nano_graphrag._storage")
        utils_module = importlib.import_module("nano_graphrag._utils")
        # Plain attribute access instead of getattr(...) with a constant
        # string literal (flake8-bugbear B009): same behavior, clearer intent.
        return {
            "GraphRAG": module.GraphRAG,
            "QueryParam": module.QueryParam,
            "Neo4jStorage": storage_module.Neo4jStorage,
            "NetworkXStorage": storage_module.NetworkXStorage,
            "EmbeddingFunc": utils_module.EmbeddingFunc,
        }

    def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
        """Return (creating on first use) the ingestion lock for *kb_id*.

        No extra synchronization needed: there is no ``await`` between the
        membership check and the insert, so the event loop cannot interleave.
        """
        if kb_id not in self._kb_locks:
            self._kb_locks[kb_id] = asyncio.Lock()
        return self._kb_locks[kb_id]

    async def _llm_complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> str:
        """LLM completion callback handed to nano-graphrag.

        Flattens ``system_prompt`` + chat ``history_messages`` + the user
        ``prompt`` into a single text prompt, invokes the non-streaming LLM
        and returns its textual content.

        Args:
            prompt: The user-level prompt text.
            system_prompt: Optional system instruction prepended to the prompt.
            history_messages: Optional prior turns; dict items are expected to
                carry ``role``/``content`` keys (OpenAI-style), anything else
                is stringified as-is.
            **kwargs: Extra options from nano-graphrag; only ``max_tokens``
                is honored (best effort, via ``model.bind``).

        Returns:
            The model's response content as a string.
        """
        history_messages = history_messages or []
        history_lines: List[str] = []
        for item in history_messages:
            if isinstance(item, dict):
                role = str(item.get("role", "user"))
                content = item.get("content", "")
                if isinstance(content, list):
                    # Multimodal-style content list: keep only the text parts.
                    joined = " ".join(
                        str(part.get("text", ""))
                        for part in content
                        if isinstance(part, dict)
                    )
                    history_lines.append(f"{role}: {joined}")
                else:
                    history_lines.append(f"{role}: {content}")
            else:
                history_lines.append(str(item))
        # Empty sections are filtered out so the prompt has no blank segments.
        full_prompt = "\n\n".join(
            part
            for part in [
                f"系统提示: {system_prompt}" if system_prompt else "",
                "历史对话:\n" + "\n".join(history_lines) if history_lines else "",
                "用户输入:\n" + prompt,
            ]
            if part
        )
        model = self._llm_model
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is not None:
            try:
                model = model.bind(max_tokens=max_tokens)
            except Exception:
                # Best effort: not every model wrapper supports bind(); fall
                # back to the unbound model instead of failing the call, but
                # leave a trace instead of swallowing the error silently.
                logger.debug(
                    "LLM model does not support bind(max_tokens=%s); ignoring",
                    max_tokens,
                    exc_info=True,
                )
        response = await model.ainvoke(full_prompt)
        # LangChain-style responses expose .content; fall back to the raw value.
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        return str(content)

    async def _embedding_call(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* off the event loop and return the vectors as an ndarray.

        ``embed_documents`` is synchronous, so it runs in a worker thread to
        avoid blocking the loop.
        """
        vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
        return np.array(vectors)

    async def _get_or_create(self, kb_id: int) -> Any:
        """Return the cached GraphRAG instance for *kb_id*, building it on first use.

        Uses double-checked locking: the fast path avoids the class-level lock
        entirely; the second check inside the lock prevents duplicate builds.
        """
        if kb_id in self._graphrag_instances:
            return self._graphrag_instances[kb_id]
        async with GraphRAGAdapter._instance_lock:
            if kb_id in self._graphrag_instances:
                return self._graphrag_instances[kb_id]
            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]
            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )
            # Default to in-process NetworkX storage; switch to Neo4j when configured.
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }
            # Each knowledge base gets its own working directory.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")
            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag

    async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
        """Insert non-empty, stripped *texts* into the knowledge base *kb_id*.

        Ingestion is serialized per knowledge base via its dedicated lock.
        No-op when every text is empty or whitespace.
        """
        cleaned = [text.strip() for text in texts if text and text.strip()]
        if not cleaned:
            return
        rag = await self._get_or_create(kb_id)
        lock = self._get_kb_lock(kb_id)
        async with lock:
            await rag.ainsert(cleaned)

    async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
        """Run a local-mode query against *kb_id* and return only the retrieved context."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="local",
            top_k=top_k,
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
        """Run a global-mode query against *kb_id* and return only the retrieved context."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="global",
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
        """Query several knowledge bases in local mode, best effort.

        Returns:
            A tuple of (concatenated context tagged per KB, list of kb_ids
            that actually contributed context). A failing KB is skipped.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                # Best effort across knowledge bases: one failing KB must not
                # abort the others, but the failure should show up in logs.
                logger.warning("local context query failed for kb_id=%s", kb_id, exc_info=True)
                continue
        return "\n\n".join(contexts), used_kb_ids

    async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
        """Query several knowledge bases in global mode, best effort.

        Returns:
            A tuple of (concatenated context tagged per KB, list of kb_ids
            that actually contributed context). A failing KB is skipped.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.global_context(kb_id, query, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                # Best effort across knowledge bases: one failing KB must not
                # abort the others, but the failure should show up in logs.
                logger.warning("global context query failed for kb_id=%s", kb_id, exc_info=True)
                continue
        return "\n\n".join(contexts), used_kb_ids