init. project

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""Graph service package: re-export GraphRAGAdapter as the public API."""
from app.services.graph.graphrag_adapter import GraphRAGAdapter
__all__ = ["GraphRAGAdapter"]

View File

@@ -0,0 +1,183 @@
import asyncio
import importlib
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
class GraphRAGAdapter:
    """Adapter that exposes nano-graphrag as a per-knowledge-base service.

    One ``GraphRAG`` instance is built lazily for each knowledge base id and
    cached for the adapter's lifetime.  The project's LLM and embedding
    factories are bridged into nano-graphrag's callback interface
    (``best_model_func`` / ``cheap_model_func`` / ``embedding_func``), and
    ingest/query helpers are provided on top.
    """

    # Guards lazy creation of per-kb GraphRAG instances (double-checked
    # locking in _get_or_create).
    # NOTE(review): created at import time; on Python < 3.10 asyncio.Lock
    # could bind to a loop at construction — confirm the target runtime.
    _instance_lock = asyncio.Lock()

    def __init__(self):
        # kb_id -> cached GraphRAG instance.
        self._graphrag_instances: Dict[int, Any] = {}
        # kb_id -> lock serializing ingestion into that kb.
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        self._embedding_model = EmbeddingsFactory.create()
        self._llm_model = LLMFactory.create(streaming=False)
        self._symbols = self._load_symbols()

    def _load_symbols(self) -> Dict[str, Any]:
        """Import nano_graphrag lazily and collect the symbols this adapter uses.

        ``Neo4jStorage`` is resolved with a default of ``None`` so that
        installs without the optional Neo4j backend can still use the
        NetworkX backend; a missing ``Neo4jStorage`` only fails later, and
        with a clear message, if the configuration actually selects neo4j.
        """
        module = importlib.import_module("nano_graphrag")
        storage_module = importlib.import_module("nano_graphrag._storage")
        utils_module = importlib.import_module("nano_graphrag._utils")
        return {
            "GraphRAG": module.GraphRAG,
            "QueryParam": module.QueryParam,
            # Optional backend: do not hard-fail at init when unavailable.
            "Neo4jStorage": getattr(storage_module, "Neo4jStorage", None),
            "NetworkXStorage": storage_module.NetworkXStorage,
            "EmbeddingFunc": utils_module.EmbeddingFunc,
        }

    def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
        """Return the ingestion lock for *kb_id*, creating it on first use."""
        return self._kb_locks.setdefault(kb_id, asyncio.Lock())

    async def _llm_complete(self, prompt: str, system_prompt: Optional[str] = None, history_messages: Optional[List[Any]] = None, **kwargs: Any) -> str:
        """Bridge nano-graphrag's completion callback onto the project LLM.

        Flattens ``system_prompt`` + ``history_messages`` + ``prompt`` into a
        single labeled prompt string (the underlying model is invoked with
        plain text, not a message list) and returns the response content as
        ``str``.
        """
        history_messages = history_messages or []
        history_lines: List[str] = []
        for item in history_messages:
            if isinstance(item, dict):
                role = str(item.get("role", "user"))
                content = item.get("content", "")
                if isinstance(content, list):
                    # Multi-part content: keep only dict parts' "text" fields.
                    joined = " ".join(str(part.get("text", "")) for part in content if isinstance(part, dict))
                    history_lines.append(f"{role}: {joined}")
                else:
                    history_lines.append(f"{role}: {content}")
            else:
                history_lines.append(str(item))
        full_prompt = "\n\n".join(
            part
            for part in [
                f"系统提示: {system_prompt}" if system_prompt else "",
                "历史对话:\n" + "\n".join(history_lines) if history_lines else "",
                "用户输入:\n" + prompt,
            ]
            if part
        )
        model = self._llm_model
        max_tokens = kwargs.get("max_tokens")
        if max_tokens is not None:
            # Best effort: not every model wrapper supports bind(max_tokens=...).
            try:
                model = model.bind(max_tokens=max_tokens)
            except Exception:
                pass
        response = await model.ainvoke(full_prompt)
        # LangChain-style responses carry .content; fall back to the raw value.
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        return str(content)

    async def _embedding_call(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* off the event loop and return a 2-D numpy array."""
        vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
        return np.array(vectors)

    async def _get_or_create(self, kb_id: int) -> Any:
        """Return the cached GraphRAG instance for *kb_id*, building it once.

        Uses double-checked locking so concurrent callers never construct two
        instances for the same kb.

        Raises:
            RuntimeError: if the configuration selects the neo4j backend but
                nano_graphrag was installed without ``Neo4jStorage``.
        """
        if kb_id in self._graphrag_instances:
            return self._graphrag_instances[kb_id]
        async with GraphRAGAdapter._instance_lock:
            # Re-check under the lock: another task may have built it already.
            if kb_id in self._graphrag_instances:
                return self._graphrag_instances[kb_id]
            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]
            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                if self._symbols["Neo4jStorage"] is None:
                    raise RuntimeError(
                        "GRAPHRAG_GRAPH_STORAGE is 'neo4j' but nano_graphrag "
                        "was installed without the Neo4jStorage backend"
                    )
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }
            # One working directory per knowledge base keeps kb data isolated.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")
            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag

    async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
        """Insert non-empty, stripped *texts* into kb *kb_id* (no-op if none).

        Ingestion is serialized per kb via _get_kb_lock.
        """
        cleaned = [text.strip() for text in texts if text and text.strip()]
        if not cleaned:
            return
        rag = await self._get_or_create(kb_id)
        lock = self._get_kb_lock(kb_id)
        async with lock:
            await rag.ainsert(cleaned)

    async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
        """Return the local-mode retrieval context (context only, no answer)."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="local",
            top_k=top_k,
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
        """Return the global-mode retrieval context (context only, no answer)."""
        rag = await self._get_or_create(kb_id)
        QueryParam = self._symbols["QueryParam"]
        param = QueryParam(
            mode="global",
            level=level,
            only_need_context=True,
        )
        return await rag.aquery(query, param)

    async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
        """Best-effort local context over several kbs.

        Failing kbs are logged and skipped rather than aborting the whole
        query.  Returns the joined context and the ids that contributed.
        """
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
            except Exception:
                logging.getLogger(__name__).warning(
                    "GraphRAG local context failed for kb_id=%s", kb_id, exc_info=True
                )
                continue
            if ctx:
                contexts.append(f"[KB:{kb_id}]\n{ctx}")
                used_kb_ids.append(kb_id)
        return "\n\n".join(contexts), used_kb_ids

    async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
        """Best-effort global context over several kbs (see local_context_multi)."""
        contexts: List[str] = []
        used_kb_ids: List[int] = []
        for kb_id in kb_ids:
            try:
                ctx = await self.global_context(kb_id, query, level=level)
            except Exception:
                logging.getLogger(__name__).warning(
                    "GraphRAG global context failed for kb_id=%s", kb_id, exc_info=True
                )
                continue
            if ctx:
                contexts.append(f"[KB:{kb_id}]\n{ctx}")
                used_kb_ids.append(kb_id)
        return "\n\n".join(contexts), used_kb_ids