import asyncio
import importlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory

class GraphRAGAdapter:
    """Adapter exposing nano-graphrag to the application, one GraphRAG per knowledge base.

    GraphRAG instances are created lazily per ``kb_id`` and cached; creation is
    guarded by a class-level asyncio lock, and ingestion into a single knowledge
    base is serialized with a per-kb lock.
    """

    # Guards lazy creation of per-kb GraphRAG instances (double-checked locking
    # in _get_or_create).
    _instance_lock = asyncio.Lock()

    def __init__(self):
        # kb_id -> cached GraphRAG instance.
        self._graphrag_instances: Dict[int, Any] = {}
        # kb_id -> lock serializing ingestion for that knowledge base.
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        self._embedding_model = EmbeddingsFactory.create()
        self._llm_model = LLMFactory.create(streaming=False)
        self._symbols = self._load_symbols()

    def _load_symbols(self) -> Dict[str, Any]:
        """Import nano_graphrag and collect the classes this adapter needs.

        Importing here (rather than at module top level) keeps nano_graphrag an
        optional dependency until the adapter is actually constructed.
        """
        module = importlib.import_module("nano_graphrag")
        storage_module = importlib.import_module("nano_graphrag._storage")
        utils_module = importlib.import_module("nano_graphrag._utils")

        return {
            "GraphRAG": module.GraphRAG,
            "QueryParam": module.QueryParam,
            "Neo4jStorage": storage_module.Neo4jStorage,
            "NetworkXStorage": storage_module.NetworkXStorage,
            "EmbeddingFunc": utils_module.EmbeddingFunc,
        }

    def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
        """Return the ingestion lock for *kb_id*, creating it on first use."""
        return self._kb_locks.setdefault(kb_id, asyncio.Lock())

    async def _llm_complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> str:
        """LLM completion callback handed to nano-graphrag.

        Flattens ``history_messages`` (OpenAI-style dicts, multi-part content
        lists, or plain values) plus the optional system prompt into a single
        prompt string, then invokes the configured model. A ``max_tokens``
        kwarg is honored when the model wrapper supports ``bind``.
        """
        history_messages = history_messages or []

        history_lines: List[str] = []
        for item in history_messages:
            if isinstance(item, dict):
                role = str(item.get("role", "user"))
                content = item.get("content", "")
                if isinstance(content, list):
                    # Multi-part content: keep only the "text" field of dict parts.
                    joined = " ".join(
                        str(part.get("text", "")) for part in content if isinstance(part, dict)
                    )
                    history_lines.append(f"{role}: {joined}")
                else:
                    history_lines.append(f"{role}: {content}")
            else:
                history_lines.append(str(item))

        # Join only the non-empty sections with blank lines between them.
        full_prompt = "\n\n".join(
            part
            for part in [
                f"系统提示: {system_prompt}" if system_prompt else "",
                "历史对话:\n" + "\n".join(history_lines) if history_lines else "",
                "用户输入:\n" + prompt,
            ]
            if part
        )

        model = self._llm_model
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is not None:
            try:
                model = model.bind(max_tokens=max_tokens)
            except Exception:
                # Best effort: not every model wrapper supports bind(); fall
                # back to the unbound model rather than failing the call.
                pass

        response = await model.ainvoke(full_prompt)
        content = getattr(response, "content", response)
        return content if isinstance(content, str) else str(content)

    async def _embedding_call(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* off the event loop and return the vectors as an ndarray."""
        vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
        return np.array(vectors)

    async def _get_or_create(self, kb_id: int) -> Any:
        """Return the cached GraphRAG for *kb_id*, building it on first access.

        Uses double-checked locking so concurrent callers never build the same
        instance twice.
        """
        instance = self._graphrag_instances.get(kb_id)
        if instance is not None:
            return instance

        async with GraphRAGAdapter._instance_lock:
            instance = self._graphrag_instances.get(kb_id)
            if instance is not None:
                return instance

            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]

            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )

            # Default to in-process NetworkX storage; switch to Neo4j when configured.
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }

            # Each knowledge base gets its own working directory.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")

            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag

    async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
        """Insert non-empty *texts* into the knowledge base's graph.

        Empty/whitespace-only entries are dropped; ingestion for the same
        kb_id is serialized by a per-kb lock.
        """
        cleaned = [text.strip() for text in texts if text and text.strip()]
        if not cleaned:
            return

        rag = await self._get_or_create(kb_id)
        async with self._get_kb_lock(kb_id):
            await rag.ainsert(cleaned)

    async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
        """Run a local-mode query for *kb_id* and return only the retrieved context."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="local",
            top_k=top_k,
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
        """Run a global-mode query for *kb_id* and return only the retrieved context."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="global",
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def _context_multi(self, kb_ids: List[int], fetch: Any) -> Tuple[str, List[int]]:
        """Collect contexts from several knowledge bases, best effort.

        ``fetch(kb_id)`` must return an awaitable yielding a context string.
        Returns the joined non-empty contexts (tagged with their kb id) and the
        list of kb_ids that contributed.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await fetch(kb_id)
                if ctx:
                    contexts.append(f"[KB:{kb_id}]\n{ctx}")
                    used_kb_ids.append(kb_id)
            except Exception:
                # Deliberate best effort: a failing knowledge base must not
                # prevent the others from contributing context.
                continue
        return "\n\n".join(contexts), used_kb_ids

    async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
        """Gather local-mode contexts across *kb_ids*; failing KBs are skipped."""
        return await self._context_multi(
            kb_ids,
            lambda kb_id: self.local_context(kb_id, query, top_k=top_k, level=level),
        )

    async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
        """Gather global-mode contexts across *kb_ids*; failing KBs are skipped."""
        return await self._context_multi(
            kb_ids,
            lambda kb_id: self.global_context(kb_id, query, level=level),
        )