# rag_agent/rag-web-ui/backend/app/services/graph/graphrag_adapter.py
import asyncio
import importlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory


class GraphRAGAdapter:
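    """Per-knowledge-base adapter around nano_graphrag.

    Builds one GraphRAG instance per knowledge base on first use, caches it,
    and serializes ingestion with per-KB locks. Instance creation itself is
    guarded by a single class-level lock.
    """
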
    _instance_lock = asyncio.Lock()

    def __init__(self):
        self._graphrag_instances: Dict[int, Any] = {}
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        self._embedding_model = EmbeddingsFactory.create()
        self._llm_model = LLMFactory.create(streaming=False)
        self._symbols = self._load_symbols()

    def _load_symbols(self) -> Dict[str, Any]:
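        """Resolve nano_graphrag symbols at runtime.

        Resolving via importlib defers the import to adapter construction
        time instead of module import time.
        """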
        module = importlib.import_module("nano_graphrag")
        storage_module = importlib.import_module("nano_graphrag._storage")
        utils_module = importlib.import_module("nano_graphrag._utils")
        return {
            "GraphRAG": module.GraphRAG,
            "QueryParam": module.QueryParam,
            "Neo4jStorage": storage_module.Neo4jStorage,
            "NetworkXStorage": storage_module.NetworkXStorage,
            "EmbeddingFunc": utils_module.EmbeddingFunc,
        }

    def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
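        """Return the per-knowledge-base lock, creating it on first use.

        Creating the lock without synchronization is safe on a single event
        loop thread because there is no await between the check and the set.
        """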
        if kb_id not in self._kb_locks:
            self._kb_locks[kb_id] = asyncio.Lock()
        return self._kb_locks[kb_id]

    async def _llm_complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> str:
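        """Adapt the chat model to nano_graphrag's completion callback.

        nano_graphrag passes OpenAI-style message dicts; this flattens the
        system prompt, history, and user prompt into a single text prompt.
        """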
        history_messages = history_messages or []
        history_lines: List[str] = []
        for item in history_messages:
            if isinstance(item, dict):
                role = str(item.get("role", "user"))
                content = item.get("content", "")
                if isinstance(content, list):
                    # Multimodal-style content: join the text parts only.
                    joined = " ".join(
                        str(part.get("text", ""))
                        for part in content
                        if isinstance(part, dict)
                    )
                    history_lines.append(f"{role}: {joined}")
                else:
                    history_lines.append(f"{role}: {content}")
            else:
                history_lines.append(str(item))
        full_prompt = "\n\n".join(
            part
            for part in [
                f"System prompt: {system_prompt}" if system_prompt else "",
                "Conversation history:\n" + "\n".join(history_lines) if history_lines else "",
                "User input:\n" + prompt,
            ]
            if part
        )
        model = self._llm_model
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is not None:
            try:
                # bind() may not be supported by every model wrapper; fall
                # back to the unbound model rather than failing the call.
                model = model.bind(max_tokens=max_tokens)
            except Exception:
                pass
        response = await model.ainvoke(full_prompt)
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        return str(content)

    async def _embedding_call(self, texts: List[str]) -> np.ndarray:
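        """Embed texts for nano_graphrag without blocking the event loop.

        embed_documents is synchronous, so it runs in a worker thread via
        asyncio.to_thread.
        """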
        vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
        return np.array(vectors)

    async def _get_or_create(self, kb_id: int) -> Any:
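        """Return the cached GraphRAG instance for kb_id, building it if needed.

        Uses double-checked locking: the fast path reads the cache without
        the lock; the slow path re-checks under the class-level lock so
        concurrent callers build each instance at most once.
        """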
        if kb_id in self._graphrag_instances:
            return self._graphrag_instances[kb_id]
        async with GraphRAGAdapter._instance_lock:
            if kb_id in self._graphrag_instances:
                return self._graphrag_instances[kb_id]
            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]
            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )
            # Default to the local NetworkX graph store; switch to Neo4j
            # when configured in settings.
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }
            # Each knowledge base gets its own working directory.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")
            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag

    async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
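        """Insert non-empty texts into the knowledge base's graph index.

        Inserts for the same kb_id are serialized with a per-KB lock so
        that concurrent ainsert calls do not interleave on one index.
        """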
        cleaned = [text.strip() for text in texts if text and text.strip()]
        if not cleaned:
            return
        rag = await self._get_or_create(kb_id)
        lock = self._get_kb_lock(kb_id)
        async with lock:
            await rag.ainsert(cleaned)

    async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
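        """Run a local-mode query and return only the retrieved context.

        only_need_context=True makes nano_graphrag return the assembled
        context string instead of a generated answer.
        """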
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="local",
            top_k=top_k,
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
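        """Run a global-mode (community-level) query and return only the context."""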
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="global",
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def local_context_multi(
        self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2
    ) -> Tuple[str, List[int]]:
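        """Query several knowledge bases and merge their local contexts.

        Failing knowledge bases are skipped (best-effort); the returned
        list records which kb_ids actually contributed context.
        """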
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                # Best-effort: a failure in one KB must not abort the others.
                continue
        return "\n\n".join(contexts), used_kb_ids

    async def global_context_multi(
        self, kb_ids: List[int], query: str, *, level: int = 2
    ) -> Tuple[str, List[int]]:
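        """Query several knowledge bases and merge their global contexts.

        Same best-effort semantics as local_context_multi.
        """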
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.global_context(kb_id, query, level=level)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                continue
        return "\n\n".join(contexts), used_kb_ids