init. project
This commit is contained in:
3
rag-web-ui/backend/app/services/graph/__init__.py
Normal file
3
rag-web-ui/backend/app/services/graph/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.services.graph.graphrag_adapter import GraphRAGAdapter
|
||||
|
||||
__all__ = ["GraphRAGAdapter"]
|
||||
183
rag-web-ui/backend/app/services/graph/graphrag_adapter.py
Normal file
183
rag-web-ui/backend/app/services/graph/graphrag_adapter.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
|
||||
|
||||
class GraphRAGAdapter:
    """Adapter exposing nano-graphrag build/query operations per knowledge base.

    Holds one lazily constructed GraphRAG instance per kb_id and serializes
    ingestion per kb with asyncio locks. nano-graphrag symbols are resolved
    dynamically in ``_load_symbols`` rather than imported at module top level.
    """

    # Class-wide lock guarding first-time construction of per-kb GraphRAG
    # instances (see _get_or_create's double-checked pattern).
    _instance_lock = asyncio.Lock()

    def __init__(self):
        # kb_id -> constructed GraphRAG instance (cache for _get_or_create).
        self._graphrag_instances: Dict[int, Any] = {}
        # kb_id -> lock serializing ainsert calls for that knowledge base.
        self._kb_locks: Dict[int, asyncio.Lock] = {}
        # Project factories supply the embedding and (non-streaming) chat models;
        # their concrete types are defined elsewhere in the app.
        self._embedding_model = EmbeddingsFactory.create()
        self._llm_model = LLMFactory.create(streaming=False)
        # Resolve nano-graphrag classes once at construction time.
        self._symbols = self._load_symbols()
|
||||
|
||||
def _load_symbols(self) -> Dict[str, Any]:
|
||||
module = importlib.import_module("nano_graphrag")
|
||||
|
||||
storage_module = importlib.import_module("nano_graphrag._storage")
|
||||
utils_module = importlib.import_module("nano_graphrag._utils")
|
||||
|
||||
return {
|
||||
"GraphRAG": module.GraphRAG,
|
||||
"QueryParam": module.QueryParam,
|
||||
"Neo4jStorage": getattr(storage_module, "Neo4jStorage"),
|
||||
"NetworkXStorage": getattr(storage_module, "NetworkXStorage"),
|
||||
"EmbeddingFunc": getattr(utils_module, "EmbeddingFunc"),
|
||||
}
|
||||
|
||||
def _get_kb_lock(self, kb_id: int) -> asyncio.Lock:
|
||||
if kb_id not in self._kb_locks:
|
||||
self._kb_locks[kb_id] = asyncio.Lock()
|
||||
return self._kb_locks[kb_id]
|
||||
|
||||
async def _llm_complete(self, prompt: str, system_prompt: Optional[str] = None, history_messages: Optional[List[Any]] = None, **kwargs: Any) -> str:
|
||||
history_messages = history_messages or []
|
||||
|
||||
history_lines: List[str] = []
|
||||
for item in history_messages:
|
||||
if isinstance(item, dict):
|
||||
role = str(item.get("role", "user"))
|
||||
content = item.get("content", "")
|
||||
if isinstance(content, list):
|
||||
joined = " ".join(str(part.get("text", "")) for part in content if isinstance(part, dict))
|
||||
history_lines.append(f"{role}: {joined}")
|
||||
else:
|
||||
history_lines.append(f"{role}: {content}")
|
||||
else:
|
||||
history_lines.append(str(item))
|
||||
|
||||
full_prompt = "\n\n".join(
|
||||
part
|
||||
for part in [
|
||||
f"系统提示: {system_prompt}" if system_prompt else "",
|
||||
"历史对话:\n" + "\n".join(history_lines) if history_lines else "",
|
||||
"用户输入:\n" + prompt,
|
||||
]
|
||||
if part
|
||||
)
|
||||
|
||||
model = self._llm_model
|
||||
max_tokens = kwargs.get("max_tokens")
|
||||
if max_tokens is not None:
|
||||
try:
|
||||
model = model.bind(max_tokens=max_tokens)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response = await model.ainvoke(full_prompt)
|
||||
content = getattr(response, "content", response)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return str(content)
|
||||
|
||||
async def _embedding_call(self, texts: List[str]) -> np.ndarray:
|
||||
vectors = await asyncio.to_thread(self._embedding_model.embed_documents, texts)
|
||||
return np.array(vectors)
|
||||
|
||||
    async def _get_or_create(self, kb_id: int) -> Any:
        """Return the cached GraphRAG instance for *kb_id*, building it on first use.

        Double-checked locking: the fast path reads the cache without the
        lock; the slow path re-checks under the class-wide lock so that two
        coroutines cannot both construct an instance for the same kb.
        """
        if kb_id in self._graphrag_instances:
            return self._graphrag_instances[kb_id]

        async with GraphRAGAdapter._instance_lock:
            # Re-check: another coroutine may have built it while we waited.
            if kb_id in self._graphrag_instances:
                return self._graphrag_instances[kb_id]

            GraphRAG = self._symbols["GraphRAG"]
            EmbeddingFunc = self._symbols["EmbeddingFunc"]

            # Wrap our async embedding callable with the dimension/token
            # limits nano-graphrag expects (values come from app settings).
            embedding_func = EmbeddingFunc(
                embedding_dim=settings.GRAPHRAG_EMBEDDING_DIM,
                max_token_size=settings.GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE,
                func=self._embedding_call,
            )

            # Default to the in-process NetworkX backend; switch to Neo4j
            # when configured, passing connection details via addon_params
            # (presumably consumed by Neo4jStorage — confirm against nano-graphrag).
            graph_storage_cls = self._symbols["NetworkXStorage"]
            addon_params: Dict[str, Any] = {}
            if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
                graph_storage_cls = self._symbols["Neo4jStorage"]
                addon_params = {
                    "neo4j_url": settings.NEO4J_URL,
                    "neo4j_auth": (settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD),
                }

            # Each knowledge base gets its own working directory on disk.
            working_dir = str(Path(settings.GRAPHRAG_WORKING_DIR) / f"kb_{kb_id}")

            # Both model tiers use the same LLM wrapper (see _llm_complete).
            rag = GraphRAG(
                working_dir=working_dir,
                enable_local=True,
                enable_naive_rag=True,
                graph_storage_cls=graph_storage_cls,
                addon_params=addon_params,
                embedding_func=embedding_func,
                best_model_func=self._llm_complete,
                cheap_model_func=self._llm_complete,
                entity_extract_max_gleaning=settings.GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING,
            )
            self._graphrag_instances[kb_id] = rag
            return rag
|
||||
|
||||
async def ingest_texts(self, kb_id: int, texts: List[str]) -> None:
|
||||
cleaned = [text.strip() for text in texts if text and text.strip()]
|
||||
if not cleaned:
|
||||
return
|
||||
|
||||
rag = await self._get_or_create(kb_id)
|
||||
lock = self._get_kb_lock(kb_id)
|
||||
async with lock:
|
||||
await rag.ainsert(cleaned)
|
||||
|
||||
async def local_context(self, kb_id: int, query: str, *, top_k: int = 20, level: int = 2) -> str:
|
||||
rag = await self._get_or_create(kb_id)
|
||||
QueryParam = self._symbols["QueryParam"]
|
||||
param = QueryParam(
|
||||
mode="local",
|
||||
top_k=top_k,
|
||||
level=level,
|
||||
only_need_context=True,
|
||||
)
|
||||
return await rag.aquery(query, param)
|
||||
|
||||
async def global_context(self, kb_id: int, query: str, *, level: int = 2) -> str:
|
||||
rag = await self._get_or_create(kb_id)
|
||||
QueryParam = self._symbols["QueryParam"]
|
||||
param = QueryParam(
|
||||
mode="global",
|
||||
level=level,
|
||||
only_need_context=True,
|
||||
)
|
||||
return await rag.aquery(query, param)
|
||||
|
||||
async def local_context_multi(self, kb_ids: List[int], query: str, *, top_k: int = 20, level: int = 2) -> Tuple[str, List[int]]:
|
||||
contexts: List[str] = []
|
||||
used_kb_ids: List[int] = []
|
||||
for kb_id in kb_ids:
|
||||
try:
|
||||
ctx = await self.local_context(kb_id, query, top_k=top_k, level=level)
|
||||
if ctx:
|
||||
contexts.append(f"[KB:{kb_id}]\n{ctx}")
|
||||
used_kb_ids.append(kb_id)
|
||||
except Exception:
|
||||
continue
|
||||
return "\n\n".join(contexts), used_kb_ids
|
||||
|
||||
async def global_context_multi(self, kb_ids: List[int], query: str, *, level: int = 2) -> Tuple[str, List[int]]:
|
||||
contexts: List[str] = []
|
||||
used_kb_ids: List[int] = []
|
||||
for kb_id in kb_ids:
|
||||
try:
|
||||
ctx = await self.global_context(kb_id, query, level=level)
|
||||
if ctx:
|
||||
contexts.append(f"[KB:{kb_id}]\n{ctx}")
|
||||
used_kb_ids.append(kb_id)
|
||||
except Exception:
|
||||
continue
|
||||
return "\n\n".join(contexts), used_kb_ids
|
||||
Reference in New Issue
Block a user