Initial project import
This commit is contained in:
7
rag-web-ui/backend/nano_graphrag/__init__.py
Normal file
7
rag-web-ui/backend/nano_graphrag/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .graphrag import GraphRAG, QueryParam
|
||||
|
||||
__version__ = "0.0.8.2"
|
||||
__author__ = "Jianbai Ye"
|
||||
__url__ = "https://github.com/gusye1234/nano-graphrag"
|
||||
|
||||
# dp stands for data pack
|
||||
301
rag-web-ui/backend/nano_graphrag/_llm.py
Normal file
301
rag-web-ui/backend/nano_graphrag/_llm.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Optional, List, Any, Callable
|
||||
|
||||
try:
|
||||
import aioboto3
|
||||
except ImportError:
|
||||
aioboto3 = None
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
import os
|
||||
|
||||
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
from .base import BaseKVStorage
|
||||
|
||||
global_openai_async_client = None
|
||||
global_azure_openai_async_client = None
|
||||
global_amazon_bedrock_async_client = None
|
||||
|
||||
|
||||
def get_openai_async_client_instance():
    """Return the process-wide ``AsyncOpenAI`` client, building it lazily.

    The first caller pays the construction cost; later callers reuse the
    same instance so connection pooling is shared across requests.
    """
    global global_openai_async_client
    if global_openai_async_client is None:
        global_openai_async_client = AsyncOpenAI()
    return global_openai_async_client
|
||||
|
||||
|
||||
def get_azure_openai_async_client_instance():
    """Return the process-wide ``AsyncAzureOpenAI`` client, building it lazily.

    Credentials/endpoint are read from the environment by the SDK itself.
    """
    global global_azure_openai_async_client
    if global_azure_openai_async_client is None:
        global_azure_openai_async_client = AsyncAzureOpenAI()
    return global_azure_openai_async_client
|
||||
|
||||
|
||||
def get_amazon_bedrock_async_client_instance():
    """Return the process-wide ``aioboto3.Session``, building it lazily.

    Raises:
        ImportError: when the optional ``aioboto3`` dependency is not installed.
    """
    global global_amazon_bedrock_async_client
    if aioboto3 is None:
        raise ImportError(
            "aioboto3 is required for Amazon Bedrock support. Install it to use Bedrock providers."
        )
    if global_amazon_bedrock_async_client is None:
        global_amazon_bedrock_async_client = aioboto3.Session()
    return global_amazon_bedrock_async_client
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Run an OpenAI chat completion, short-circuiting through a KV cache.

    Args:
        model: OpenAI model name passed straight through to the API.
        prompt: User message content.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior conversation turns (list of message dicts).
        **kwargs: Extra API arguments; ``hashing_kv`` (a ``BaseKVStorage``)
            is popped out and used as a response cache keyed by a hash of
            (model, messages).

    Returns:
        The assistant message content as a string.
    """
    # Fix: the default was a shared mutable list ([]); use None and rebind.
    if history_messages is None:
        history_messages = []
    openai_async_client = get_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        # Serve an identical earlier request straight from the cache.
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Chat completion through Amazon Bedrock's Converse API, with KV caching.

    Args:
        model: Bedrock model identifier.
        prompt: User message content.
        system_prompt: Optional system text, passed via Converse's ``system`` field.
        history_messages: Prior conversation turns in Converse message format.
        **kwargs: ``hashing_kv`` (a ``BaseKVStorage``) is popped and used as a
            response cache; ``max_tokens`` overrides the 4096 default.

    Returns:
        The first text block of the model's reply.
    """
    # Fix: the default was a shared mutable list ([]); use None and rebind.
    if history_messages is None:
        history_messages = []
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    messages.extend(history_messages)
    # Converse wraps message content in a list of content blocks.
    messages.append({"role": "user", "content": [{"text": prompt}]})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    inference_config = {
        "temperature": 0,
        "maxTokens": kwargs.get("max_tokens", 4096),
    }

    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        # Only pass `system` when a system prompt was actually supplied;
        # Converse rejects an empty system field.
        converse_kwargs = dict(
            modelId=model, messages=messages, inferenceConfig=inference_config,
        )
        if system_prompt:
            converse_kwargs["system"] = [{"text": system_prompt}]
        response = await bedrock_runtime.converse(**converse_kwargs)

    answer = response["output"]["message"]["content"][0]["text"]
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": answer, "model": model}}
        )
        await hashing_kv.index_done_callback()
    return answer
|
||||
|
||||
|
||||
def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
    """
    Factory function to dynamically create completion functions for Amazon Bedrock

    Args:
        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")

    Returns:
        Callable: Generated completion function
    """
    async def bedrock_complete(
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: Optional[List[Any]] = None,
        **kwargs
    ) -> str:
        # Fix: the inner default was a shared mutable list ([]); normalise
        # here so the callee always receives a fresh list.
        return await amazon_bedrock_complete_if_cache(
            model_id,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages if history_messages is not None else [],
            **kwargs
        )

    # Set function name for easier debugging
    bedrock_complete.__name__ = f"{model_id}_complete"

    return bedrock_complete
|
||||
|
||||
|
||||
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o" model via the caching OpenAI helper.

    Thin wrapper around :func:`openai_complete_if_cache`; all kwargs
    (including ``hashing_kv``) pass straight through.
    """
    # Fix: mutable default argument; normalise before delegating.
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
|
||||
|
||||
|
||||
async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* with the "gpt-4o-mini" model via the caching OpenAI helper.

    Thin wrapper around :func:`openai_complete_if_cache`; all kwargs
    (including ``hashing_kv``) pass straight through.
    """
    # Fix: mutable default argument; normalise before delegating.
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* one by one with Amazon Titan Embed Text v2 (1024 dims).

    Returns an array of shape (len(texts), 1024) in input order.
    """
    session = get_amazon_bedrock_async_client_instance()

    async with session.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        parsed_responses = []
        # Titan's invoke_model API takes a single input, so loop sequentially.
        for text in texts:
            payload = json.dumps(
                {
                    "inputText": text,
                    "dimensions": 1024,
                }
            )
            response = await bedrock_runtime.invoke_model(
                modelId="amazon.titan-embed-text-v2:0", body=payload,
            )
            raw_body = await response.get("body").read()
            parsed_responses.append(json.loads(raw_body))
    return np.array([dp["embedding"] for dp in parsed_responses])
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* in one OpenAI ``text-embedding-3-small`` call (1536 dims)."""
    client = get_openai_async_client_instance()
    response = await client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in response.data]
    return np.array(vectors)
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_complete_if_cache(
    deployment_name, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Run an Azure OpenAI chat completion, short-circuiting through a KV cache.

    Mirrors :func:`openai_complete_if_cache` but targets an Azure deployment.

    Args:
        deployment_name: Azure deployment used as the ``model`` argument and
            as part of the cache key.
        prompt: User message content.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior conversation turns (list of message dicts).
        **kwargs: Extra API arguments; ``hashing_kv`` (a ``BaseKVStorage``)
            is popped out and used as the response cache.

    Returns:
        The assistant message content as a string.
    """
    # Fix: the default was a shared mutable list ([]); use None and rebind.
    if history_messages is None:
        history_messages = []
    azure_openai_client = get_azure_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(deployment_name, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await azure_openai_client.chat.completions.create(
        model=deployment_name, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response.choices[0].message.content,
                    "model": deployment_name,
                }
            }
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
|
||||
|
||||
|
||||
async def azure_gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* against the Azure "gpt-4o" deployment via the caching helper."""
    # Fix: mutable default argument; normalise before delegating.
    return await azure_openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
|
||||
|
||||
|
||||
async def azure_gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* against the Azure "gpt-4o-mini" deployment via the caching helper."""
    # Fix: mutable default argument; normalise before delegating.
    return await azure_openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed *texts* in one Azure ``text-embedding-3-small`` call (1536 dims)."""
    client = get_azure_openai_async_client_instance()
    response = await client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in response.data]
    return np.array(vectors)
|
||||
1140
rag-web-ui/backend/nano_graphrag/_op.py
Normal file
1140
rag-web-ui/backend/nano_graphrag/_op.py
Normal file
File diff suppressed because it is too large
Load Diff
94
rag-web-ui/backend/nano_graphrag/_splitter.py
Normal file
94
rag-web-ui/backend/nano_graphrag/_splitter.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from typing import List, Optional, Union, Literal
|
||||
|
||||
class SeparatorSplitter:
|
||||
def __init__(
|
||||
self,
|
||||
separators: Optional[List[List[int]]] = None,
|
||||
keep_separator: Union[bool, Literal["start", "end"]] = "end",
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: callable = len,
|
||||
):
|
||||
self._separators = separators or []
|
||||
self._keep_separator = keep_separator
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
|
||||
def split_tokens(self, tokens: List[int]) -> List[List[int]]:
|
||||
splits = self._split_tokens_with_separators(tokens)
|
||||
return self._merge_splits(splits)
|
||||
|
||||
def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
|
||||
splits = []
|
||||
current_split = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
separator_found = False
|
||||
for separator in self._separators:
|
||||
if tokens[i:i+len(separator)] == separator:
|
||||
if self._keep_separator in [True, "end"]:
|
||||
current_split.extend(separator)
|
||||
if current_split:
|
||||
splits.append(current_split)
|
||||
current_split = []
|
||||
if self._keep_separator == "start":
|
||||
current_split.extend(separator)
|
||||
i += len(separator)
|
||||
separator_found = True
|
||||
break
|
||||
if not separator_found:
|
||||
current_split.append(tokens[i])
|
||||
i += 1
|
||||
if current_split:
|
||||
splits.append(current_split)
|
||||
return [s for s in splits if s]
|
||||
|
||||
def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
|
||||
if not splits:
|
||||
return []
|
||||
|
||||
merged_splits = []
|
||||
current_chunk = []
|
||||
|
||||
for split in splits:
|
||||
if not current_chunk:
|
||||
current_chunk = split
|
||||
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
|
||||
current_chunk.extend(split)
|
||||
else:
|
||||
merged_splits.append(current_chunk)
|
||||
current_chunk = split
|
||||
|
||||
if current_chunk:
|
||||
merged_splits.append(current_chunk)
|
||||
|
||||
if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
|
||||
return self._split_chunk(merged_splits[0])
|
||||
|
||||
if self._chunk_overlap > 0:
|
||||
return self._enforce_overlap(merged_splits)
|
||||
|
||||
return merged_splits
|
||||
|
||||
def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
|
||||
result = []
|
||||
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
|
||||
new_chunk = chunk[i:i + self._chunk_size]
|
||||
if len(new_chunk) > self._chunk_overlap: # 只有当 chunk 长度大于 overlap 时才添加
|
||||
result.append(new_chunk)
|
||||
return result
|
||||
|
||||
def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
|
||||
result = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
result.append(chunk)
|
||||
else:
|
||||
overlap = chunks[i-1][-self._chunk_overlap:]
|
||||
new_chunk = overlap + chunk
|
||||
if self._length_function(new_chunk) > self._chunk_size:
|
||||
new_chunk = new_chunk[:self._chunk_size]
|
||||
result.append(new_chunk)
|
||||
return result
|
||||
|
||||
9
rag-web-ui/backend/nano_graphrag/_storage/__init__.py
Normal file
9
rag-web-ui/backend/nano_graphrag/_storage/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .gdb_networkx import NetworkXStorage
|
||||
from .gdb_neo4j import Neo4jStorage
|
||||
from .vdb_nanovectordb import NanoVectorDBStorage
|
||||
from .kv_json import JsonKVStorage
|
||||
|
||||
try:
|
||||
from .vdb_hnswlib import HNSWVectorStorage
|
||||
except ImportError:
|
||||
HNSWVectorStorage = None
|
||||
529
rag-web-ui/backend/nano_graphrag/_storage/gdb_neo4j.py
Normal file
529
rag-web-ui/backend/nano_graphrag/_storage/gdb_neo4j.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import json
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from ..base import BaseGraphStorage, SingleCommunitySchema
|
||||
from .._utils import logger
|
||||
from ..prompt import GRAPH_FIELD_SEP
|
||||
|
||||
neo4j_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def make_path_idable(path):
    """Turn a filesystem path into a string safe to use as a Neo4j label.

    Path separators ('/' and '\\') widen to double underscores while '.',
    '-' and ':' collapse to single ones, matching the original chained
    ``str.replace`` behaviour in a single translation pass.
    """
    table = str.maketrans({".": "_", "/": "__", "-": "_", ":": "_", "\\": "__"})
    return path.translate(table)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4jStorage(BaseGraphStorage):
|
||||
def __post_init__(self):
|
||||
self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
|
||||
self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
|
||||
self.namespace = (
|
||||
f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
|
||||
)
|
||||
logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
|
||||
if self.neo4j_url is None or self.neo4j_auth is None:
|
||||
raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
|
||||
self.async_driver = AsyncGraphDatabase.driver(
|
||||
self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
|
||||
)
|
||||
|
||||
# async def create_database(self):
|
||||
# async with self.async_driver.session() as session:
|
||||
# try:
|
||||
# constraints = await session.run("SHOW CONSTRAINTS")
|
||||
# # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
|
||||
# # so have to check if the constrain exists
|
||||
# constrain_exists = False
|
||||
|
||||
# async for record in constraints:
|
||||
# if (
|
||||
# self.namespace in record["labelsOrTypes"]
|
||||
# and "id" in record["properties"]
|
||||
# and record["type"] == "UNIQUENESS"
|
||||
# ):
|
||||
# constrain_exists = True
|
||||
# break
|
||||
# if not constrain_exists:
|
||||
# await session.run(
|
||||
# f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
|
||||
# )
|
||||
# logger.info(f"Add constraint for namespace: {self.namespace}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error accessing or setting up the database: {str(e)}")
|
||||
# raise
|
||||
|
||||
async def _init_workspace(self):
|
||||
await self.async_driver.verify_authentication()
|
||||
await self.async_driver.verify_connectivity()
|
||||
# TODOLater: create database if not exists always cause an error when async
|
||||
# await self.create_database()
|
||||
|
||||
async def index_start_callback(self):
|
||||
logger.info("Init Neo4j workspace")
|
||||
await self._init_workspace()
|
||||
|
||||
# create index for faster searching
|
||||
try:
|
||||
async with self.async_driver.session() as session:
|
||||
await session.run(
|
||||
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
|
||||
)
|
||||
|
||||
await session.run(
|
||||
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
|
||||
)
|
||||
|
||||
await session.run(
|
||||
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
|
||||
)
|
||||
|
||||
await session.run(
|
||||
f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
|
||||
)
|
||||
logger.info("Neo4j indexes created successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create indexes: {e}")
|
||||
raise e
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
|
||||
node_id=node_id,
|
||||
)
|
||||
record = await result.single()
|
||||
return record["exists"] if record else False
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
MATCH (s:`{self.namespace}`)
|
||||
WHERE s.id = $source_id
|
||||
MATCH (t:`{self.namespace}`)
|
||||
WHERE t.id = $target_id
|
||||
RETURN EXISTS((s)-[]->(t)) AS exists
|
||||
""",
|
||||
source_id=source_node_id,
|
||||
target_id=target_node_id,
|
||||
)
|
||||
|
||||
record = await result.single()
|
||||
return record["exists"] if record else False
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
results = await self.node_degrees_batch([node_id])
|
||||
return results[0] if results else 0
|
||||
|
||||
async def node_degrees_batch(self, node_ids: List[str]) -> List[str]:
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
result_dict = {node_id: 0 for node_id in node_ids}
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
UNWIND $node_ids AS node_id
|
||||
MATCH (n:`{self.namespace}`)
|
||||
WHERE n.id = node_id
|
||||
OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
|
||||
RETURN node_id, COUNT(m) AS degree
|
||||
""",
|
||||
node_ids=node_ids
|
||||
)
|
||||
|
||||
async for record in result:
|
||||
result_dict[record["node_id"]] = record["degree"]
|
||||
|
||||
return [result_dict[node_id] for node_id in node_ids]
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
results = await self.edge_degrees_batch([(src_id, tgt_id)])
|
||||
return results[0] if results else 0
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
|
||||
if not edge_pairs:
|
||||
return []
|
||||
|
||||
result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}
|
||||
|
||||
edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]
|
||||
|
||||
try:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
UNWIND $edges AS edge
|
||||
|
||||
MATCH (s:`{self.namespace}`)
|
||||
WHERE s.id = edge.src_id
|
||||
WITH edge, s
|
||||
OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
|
||||
WITH edge, COUNT(n1) AS src_degree
|
||||
|
||||
MATCH (t:`{self.namespace}`)
|
||||
WHERE t.id = edge.tgt_id
|
||||
WITH edge, src_degree, t
|
||||
OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
|
||||
WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree
|
||||
|
||||
RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
|
||||
""",
|
||||
edges=edges_params
|
||||
)
|
||||
|
||||
async for record in result:
|
||||
src_id = record["src_id"]
|
||||
tgt_id = record["tgt_id"]
|
||||
degree = record["degree"]
|
||||
|
||||
# 更新结果字典
|
||||
edge_pair = (src_id, tgt_id)
|
||||
result_dict[edge_pair] = degree
|
||||
|
||||
return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch edge degree calculation: {e}")
|
||||
return [0] * len(edge_pairs)
|
||||
|
||||
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
result = await self.get_nodes_batch([node_id])
|
||||
return result[0] if result else None
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
result_dict = {node_id: None for node_id in node_ids}
|
||||
|
||||
try:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
UNWIND $node_ids AS node_id
|
||||
MATCH (n:`{self.namespace}`)
|
||||
WHERE n.id = node_id
|
||||
RETURN node_id, properties(n) AS node_data
|
||||
""",
|
||||
node_ids=node_ids
|
||||
)
|
||||
|
||||
async for record in result:
|
||||
node_id = record["node_id"]
|
||||
raw_node_data = record["node_data"]
|
||||
|
||||
if raw_node_data:
|
||||
raw_node_data["clusters"] = json.dumps(
|
||||
[
|
||||
{
|
||||
"level": index,
|
||||
"cluster": cluster_id,
|
||||
}
|
||||
for index, cluster_id in enumerate(
|
||||
raw_node_data.get("communityIds", [])
|
||||
)
|
||||
]
|
||||
)
|
||||
result_dict[node_id] = raw_node_data
|
||||
return [result_dict[node_id] for node_id in node_ids]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch node retrieval: {e}")
|
||||
raise e
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
results = await self.get_edges_batch([(source_node_id, target_node_id)])
|
||||
return results[0] if results else None
|
||||
|
||||
async def get_edges_batch(
|
||||
self, edge_pairs: list[tuple[str, str]]
|
||||
) -> list[Union[dict, None]]:
|
||||
if not edge_pairs:
|
||||
return []
|
||||
|
||||
result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}
|
||||
|
||||
edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]
|
||||
|
||||
try:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
|
||||
WHERE s.id = edge.source_id AND t.id = edge.target_id
|
||||
RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
|
||||
""",
|
||||
edges=edges_params
|
||||
)
|
||||
|
||||
async for record in result:
|
||||
source_id = record["source_id"]
|
||||
target_id = record["target_id"]
|
||||
edge_data = record["edge_data"]
|
||||
|
||||
edge_pair = (source_id, target_id)
|
||||
result_dict[edge_pair] = edge_data
|
||||
|
||||
return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch edge retrieval: {e}")
|
||||
return [None] * len(edge_pairs)
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> list[tuple[str, str]]:
|
||||
results = await self.get_nodes_edges_batch([source_node_id])
|
||||
return results[0] if results else []
|
||||
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> list[list[tuple[str, str]]]:
|
||||
if not node_ids:
|
||||
return []
|
||||
|
||||
result_dict = {node_id: [] for node_id in node_ids}
|
||||
|
||||
try:
|
||||
async with self.async_driver.session() as session:
|
||||
result = await session.run(
|
||||
f"""
|
||||
UNWIND $node_ids AS node_id
|
||||
MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
|
||||
WHERE s.id = node_id
|
||||
RETURN s.id AS source_id, t.id AS target_id
|
||||
""",
|
||||
node_ids=node_ids
|
||||
)
|
||||
|
||||
async for record in result:
|
||||
source_id = record["source_id"]
|
||||
target_id = record["target_id"]
|
||||
|
||||
if source_id in result_dict:
|
||||
result_dict[source_id].append((source_id, target_id))
|
||||
|
||||
return [result_dict[node_id] for node_id in node_ids]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch node edges retrieval: {e}")
|
||||
return [[] for _ in node_ids]
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
await self.upsert_nodes_batch([(node_id, node_data)])
|
||||
|
||||
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
|
||||
if not nodes_data:
|
||||
return []
|
||||
|
||||
nodes_by_type = {}
|
||||
for node_id, node_data in nodes_data:
|
||||
node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
|
||||
if node_type not in nodes_by_type:
|
||||
nodes_by_type[node_type] = []
|
||||
nodes_by_type[node_type].append((node_id, node_data))
|
||||
|
||||
async with self.async_driver.session() as session:
|
||||
for node_type, type_nodes in nodes_by_type.items():
|
||||
params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]
|
||||
|
||||
await session.run(
|
||||
f"""
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
|
||||
SET n += node.data
|
||||
""",
|
||||
nodes=params
|
||||
)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])
|
||||
|
||||
|
||||
async def upsert_edges_batch(
|
||||
self, edges_data: list[tuple[str, str, dict[str, str]]]
|
||||
):
|
||||
if not edges_data:
|
||||
return
|
||||
|
||||
edges_params = []
|
||||
for source_id, target_id, edge_data in edges_data:
|
||||
edge_data_copy = edge_data.copy()
|
||||
edge_data_copy.setdefault("weight", 0.0)
|
||||
|
||||
edges_params.append({
|
||||
"source_id": source_id,
|
||||
"target_id": target_id,
|
||||
"edge_data": edge_data_copy
|
||||
})
|
||||
|
||||
async with self.async_driver.session() as session:
|
||||
await session.run(
|
||||
f"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (s:`{self.namespace}`)
|
||||
WHERE s.id = edge.source_id
|
||||
WITH edge, s
|
||||
MATCH (t:`{self.namespace}`)
|
||||
WHERE t.id = edge.target_id
|
||||
MERGE (s)-[r:RELATED]->(t)
|
||||
SET r += edge.edge_data
|
||||
""",
|
||||
edges=edges_params
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
async def clustering(self, algorithm: str):
|
||||
if algorithm != "leiden":
|
||||
raise ValueError(
|
||||
f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
|
||||
)
|
||||
|
||||
random_seed = self.global_config["graph_cluster_seed"]
|
||||
max_level = self.global_config["max_graph_cluster_size"]
|
||||
async with self.async_driver.session() as session:
|
||||
try:
|
||||
# Project the graph with undirected relationships
|
||||
await session.run(
|
||||
f"""
|
||||
CALL gds.graph.project(
|
||||
'graph_{self.namespace}',
|
||||
['{self.namespace}'],
|
||||
{{
|
||||
RELATED: {{
|
||||
orientation: 'UNDIRECTED',
|
||||
properties: ['weight']
|
||||
}}
|
||||
}}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Run Leiden algorithm
|
||||
result = await session.run(
|
||||
f"""
|
||||
CALL gds.leiden.write(
|
||||
'graph_{self.namespace}',
|
||||
{{
|
||||
writeProperty: 'communityIds',
|
||||
includeIntermediateCommunities: True,
|
||||
relationshipWeightProperty: "weight",
|
||||
maxLevels: {max_level},
|
||||
tolerance: 0.0001,
|
||||
gamma: 1.0,
|
||||
theta: 0.01,
|
||||
randomSeed: {random_seed}
|
||||
}}
|
||||
)
|
||||
YIELD communityCount, modularities;
|
||||
"""
|
||||
)
|
||||
result = await result.single()
|
||||
community_count: int = result["communityCount"]
|
||||
modularities = result["modularities"]
|
||||
logger.info(
|
||||
f"Performed graph clustering with {community_count} communities and modularities {modularities}"
|
||||
)
|
||||
finally:
|
||||
# Drop the projected graph
|
||||
await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
|
||||
|
||||
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
|
||||
results = defaultdict(
|
||||
lambda: dict(
|
||||
level=None,
|
||||
title=None,
|
||||
edges=set(),
|
||||
nodes=set(),
|
||||
chunk_ids=set(),
|
||||
occurrence=0.0,
|
||||
sub_communities=[],
|
||||
)
|
||||
)
|
||||
|
||||
async with self.async_driver.session() as session:
|
||||
# Fetch community data
|
||||
result = await session.run(
|
||||
f"""
|
||||
MATCH (n:`{self.namespace}`)
|
||||
WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
|
||||
RETURN n.id AS node_id, n.source_id AS source_id,
|
||||
communityIds AS cluster_key,
|
||||
connected_nodes
|
||||
"""
|
||||
)
|
||||
|
||||
# records = await result.fetch()
|
||||
|
||||
max_num_ids = 0
|
||||
async for record in result:
|
||||
for index, c_id in enumerate(record["cluster_key"]):
|
||||
node_id = str(record["node_id"])
|
||||
source_id = record["source_id"]
|
||||
level = index
|
||||
cluster_key = str(c_id)
|
||||
connected_nodes = record["connected_nodes"]
|
||||
|
||||
results[cluster_key]["level"] = level
|
||||
results[cluster_key]["title"] = f"Cluster {cluster_key}"
|
||||
results[cluster_key]["nodes"].add(node_id)
|
||||
results[cluster_key]["edges"].update(
|
||||
[
|
||||
tuple(sorted([node_id, str(connected)]))
|
||||
for connected in connected_nodes
|
||||
if connected != node_id
|
||||
]
|
||||
)
|
||||
chunk_ids = source_id.split(GRAPH_FIELD_SEP)
|
||||
results[cluster_key]["chunk_ids"].update(chunk_ids)
|
||||
max_num_ids = max(
|
||||
max_num_ids, len(results[cluster_key]["chunk_ids"])
|
||||
)
|
||||
|
||||
# Process results
|
||||
for k, v in results.items():
|
||||
v["edges"] = [list(e) for e in v["edges"]]
|
||||
v["nodes"] = list(v["nodes"])
|
||||
v["chunk_ids"] = list(v["chunk_ids"])
|
||||
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
|
||||
|
||||
# Compute sub-communities (this is a simplified approach)
|
||||
for cluster in results.values():
|
||||
cluster["sub_communities"] = [
|
||||
sub_key
|
||||
for sub_key, sub_cluster in results.items()
|
||||
if sub_cluster["level"] > cluster["level"]
|
||||
and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
|
||||
]
|
||||
|
||||
return dict(results)
|
||||
|
||||
    async def index_done_callback(self):
        """Called once indexing finishes: close the async Neo4j driver.

        NOTE(review): this closes the driver for good — further graph calls on
        this instance after indexing would fail; confirm callers treat this as
        a terminal hook.
        """
        await self.async_driver.close()
|
||||
|
||||
    async def _debug_delete_all_node_edges(self):
        """Debug-only helper: delete every relationship and node in this namespace.

        Raises:
            Exception: re-raises whatever the Neo4j session raised, after logging.
        """
        async with self.async_driver.session() as session:
            try:
                # Delete all relationships in the namespace first — Neo4j cannot
                # delete a node that still has relationships attached.
                await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")

                # Delete all nodes in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")

                logger.info(
                    f"All nodes and edges in namespace '{self.namespace}' have been deleted."
                )
            except Exception as e:
                logger.error(f"Error deleting nodes and edges: {str(e)}")
                raise
|
||||
268
rag-web-ui/backend/nano_graphrag/_storage/gdb_networkx.py
Normal file
268
rag-web-ui/backend/nano_graphrag/_storage/gdb_networkx.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast, List
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
SingleCommunitySchema,
|
||||
)
|
||||
from ..prompt import GRAPH_FIELD_SEP
|
||||
|
||||
|
||||
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """Graph storage backed by an in-memory networkx graph, persisted as GraphML."""

    @staticmethod
    def load_nx_graph(file_name) -> Union[nx.Graph, None]:
        """Load a GraphML file if it exists; otherwise return None."""
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Persist `graph` to `file_name` in GraphML format."""
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        # Normalize node labels: HTML-unescape, uppercase, strip whitespace.
        node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))

        if not graph.is_directed():

            # For undirected graphs, orient each edge lexicographically so the
            # same relationship always serializes identically.
            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    temp = source
                    source = target
                    target = temp
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        # GraphML file path derives from the configured working dir + namespace.
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        # NOTE: `or` also replaces an *empty* preloaded graph with a fresh one;
        # functionally equivalent since both are empty graphs.
        self._graph = preloaded_graph or nx.Graph()
        self._clustering_algorithms = {
            "leiden": self._leiden_clustering,
        }
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        """Flush the in-memory graph to its GraphML file."""
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's attribute dict, or None when absent."""
        return self._graph.nodes.get(node_id)

    async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
        # Annotation fixed: asyncio.gather returns a list, not a dict.
        return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])

    async def node_degree(self, node_id: str) -> int:
        # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
        return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0

    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        # Annotation fixed: degrees are ints, not strs.
        return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Sum of the two endpoint degrees; missing endpoints count as 0."""
        return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
            self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
        )

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        return await asyncio.gather(*[self.edge_degree(src_id, tgt_id) for src_id, tgt_id in edge_pairs])

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the edge's attribute dict, or None when absent."""
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        return await asyncio.gather(*[self.get_edge(source_node_id, target_node_id) for source_node_id, target_node_id in edge_pairs])

    async def get_node_edges(self, source_node_id: str):
        """Return the list of incident edges, or None when the node is absent."""
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        return await asyncio.gather(*[self.get_node_edges(node_id) for node_id in node_ids])

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert the node or merge `node_data` into its attributes."""
        self._graph.add_node(node_id, **node_data)

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        await asyncio.gather(*[self.upsert_node(node_id, node_data) for node_id, node_data in nodes_data])

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Insert the edge or merge `edge_data` into its attributes."""
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        await asyncio.gather(*[self.upsert_edge(source_node_id, target_node_id, edge_data)
                               for source_node_id, target_node_id, edge_data in edges_data])

    async def clustering(self, algorithm: str):
        """Run the named clustering algorithm (currently only 'leiden')."""
        if algorithm not in self._clustering_algorithms:
            raise ValueError(f"Clustering algorithm {algorithm} not supported")
        await self._clustering_algorithms[algorithm]()

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Aggregate per-node cluster annotations into community descriptors.

        Reads the JSON `clusters` attribute written by `_cluster_data_to_subgraphs`
        and the GRAPH_FIELD_SEP-joined `source_id` attribute of each node.
        """
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        max_num_ids = 0
        levels = defaultdict(set)
        for node_id, node_data in self._graph.nodes(data=True):
            if "clusters" not in node_data:
                continue
            clusters = json.loads(node_data["clusters"])
            this_node_edges = self._graph.edges(node_id)

            for cluster in clusters:
                level = cluster["level"]
                cluster_key = str(cluster["cluster"])
                levels[level].add(cluster_key)
                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                results[cluster_key]["edges"].update(
                    [tuple(sorted(e)) for e in this_node_edges]
                )
                results[cluster_key]["chunk_ids"].update(
                    node_data["source_id"].split(GRAPH_FIELD_SEP)
                )
                max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))

        # `max_num_ids` >= 1 whenever `results` is non-empty, so the division
        # below cannot divide by zero.
        ordered_levels = sorted(levels.keys())
        for i, curr_level in enumerate(ordered_levels[:-1]):
            next_level = ordered_levels[i + 1]
            this_level_comms = levels[curr_level]
            next_level_comms = levels[next_level]
            # compute the sub-communities by nodes intersection
            for comm in this_level_comms:
                results[comm]["sub_communities"] = [
                    c
                    for c in next_level_comms
                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                ]

        # Convert the working sets into JSON-serializable lists.
        for k, v in results.items():
            v["edges"] = list(v["edges"])
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        return dict(results)

    def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
        """Store each node's cluster memberships as a JSON string attribute."""
        for node_id, clusters in cluster_data.items():
            self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)

    async def _leiden_clustering(self):
        """Hierarchical Leiden clustering over the largest connected component."""
        from graspologic.partition import hierarchical_leiden

        graph = NetworkXStorage.stable_largest_connected_component(self._graph)
        community_mapping = hierarchical_leiden(
            graph,
            max_cluster_size=self.global_config["max_graph_cluster_size"],
            random_seed=self.global_config["graph_cluster_seed"],
        )

        node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
        __levels = defaultdict(set)
        for partition in community_mapping:
            level_key = partition.level
            cluster_id = partition.cluster
            node_communities[partition.node].append(
                {"level": level_key, "cluster": cluster_id}
            )
            __levels[level_key].add(cluster_id)
        node_communities = dict(node_communities)
        __levels = {k: len(v) for k, v in __levels.items()}
        logger.info(f"Each level has communities: {dict(__levels)}")
        self._cluster_data_to_subgraphs(node_communities)

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Run the named node-embedding algorithm (currently only 'node2vec')."""
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    async def _node2vec_embed(self):
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        # NOTE(review): assumes every node carries an "id" attribute — raises
        # KeyError otherwise; confirm upstream always sets it.
        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids
|
||||
46
rag-web-ui/backend/nano_graphrag/_storage/kv_json.py
Normal file
46
rag-web-ui/backend/nano_graphrag/_storage/kv_json.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._utils import load_json, logger, write_json
|
||||
from ..base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage persisted as a single JSON file in the working dir."""

    def __post_init__(self):
        workdir = self.global_config["working_dir"]
        self._file_name = os.path.join(workdir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        """Return every stored key."""
        return [*self._data]

    async def index_done_callback(self):
        """Flush the in-memory store to its JSON file."""
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """Look up one record; None when missing."""
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        """Batch lookup; optionally project each record onto `fields`."""
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        projected = []
        for id in ids:
            record = self._data.get(id, None)
            if record:
                projected.append({k: v for k, v in record.items() if k in fields})
            else:
                projected.append(None)
        return projected

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of `data` not yet present in the store."""
        return set(data) - set(self._data)

    async def upsert(self, data: dict[str, dict]):
        """Insert or overwrite records in bulk."""
        self._data.update(data)

    async def drop(self):
        """Discard all in-memory records (file untouched until next flush)."""
        self._data = {}
|
||||
141
rag-web-ui/backend/nano_graphrag/_storage/vdb_hnswlib.py
Normal file
141
rag-web-ui/backend/nano_graphrag/_storage/vdb_hnswlib.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import pickle
|
||||
import hnswlib
|
||||
import numpy as np
|
||||
import xxhash
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
|
||||
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an hnswlib cosine index, persisted to disk."""

    ef_construction: int = 100  # HNSW build-time search width
    M: int = 16  # HNSW graph connectivity
    max_elements: int = 1000000  # hard capacity of the index
    ef_search: int = 50  # query-time search width
    num_threads: int = -1  # hnswlib thread count (-1 = library default)
    _index: Any = field(init=False)
    # Keys are xxh32 int digests of the record id (see upsert), not strings.
    _metadata: dict[int, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)

        # Per-instance overrides from the global config take precedence over
        # the dataclass defaults.
        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )

        # Resume from disk only when BOTH the index and its metadata exist.
        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        """Embed `content` of each record and add the vectors to the index.

        Returns the np.uint32 array of xxhash ids, or [] for empty input.
        Raises ValueError when the insert would exceed `max_elements`.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )

        # Keep only the configured meta fields alongside the record id.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batch_size = min(self._embedding_batch_num, len(contents))
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )

        # Map string ids to 32-bit ints for hnswlib.
        # NOTE(review): xxh32 digests can collide; a collision silently
        # overwrites the older record's metadata and index entry.
        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Return up to `top_k` nearest records with distance and similarity."""
        if self._current_elements == 0:
            return []

        top_k = min(top_k, self._current_elements)

        # hnswlib requires ef >= k; raise ef for this and subsequent queries.
        # NOTE: self.ef_search is not updated, so the warning may repeat.
        if top_k > self.ef_search:
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)

        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )

        # Cosine distance d -> similarity 1 - d.
        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        """Persist the index and its (metadata, count) companion pickle."""
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)
|
||||
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
|
||||
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by NanoVectorDB, persisted as a JSON file."""

    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._client_file_name = os.path.join(working_dir, f"vdb_{self.namespace}.json")
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        # The global config may override the cosine cut-off.
        self.cosine_better_than_threshold = self.global_config.get(
            "query_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed each record's `content` and upsert vectors plus meta fields."""
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        records = []
        for key, value in data.items():
            entry = {"__id__": key}
            entry.update({k1: v1 for k1, v1 in value.items() if k1 in self.meta_fields})
            records.append(entry)
        contents = [v["content"] for v in data.values()]
        step = self._max_batch_size
        batches = [contents[pos : pos + step] for pos in range(0, len(contents), step)]
        embeddings_list = await asyncio.gather(
            *(self.embedding_func(batch) for batch in batches)
        )
        embeddings = np.concatenate(embeddings_list)
        for idx, record in enumerate(records):
            record["__vector__"] = embeddings[idx]
        return self._client.upsert(datas=records)

    async def query(self, query: str, top_k=5):
        """Return up to `top_k` matches above the cosine threshold."""
        embedding = (await self.embedding_func([query]))[0]
        raw = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        return [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in raw
        ]

    async def index_done_callback(self):
        """Persist the vector DB to its JSON file."""
        self._client.save()
|
||||
307
rag-web-ui/backend/nano_graphrag/_utils.py
Normal file
307
rag-web-ui/backend/nano_graphrag/_utils.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import asyncio
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import numbers
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
from typing import Any, Union, Literal
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
AutoTokenizer = None
|
||||
|
||||
logger = logging.getLogger("nano-graphrag")
|
||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return the current event loop, creating and installing one if needed."""
    try:
        # If there is already an event loop, use it.
        return asyncio.get_event_loop()
    except RuntimeError:
        # No loop is associated with this (sub-)thread: make a fresh one.
        logger.info("Creating a new event loop in a sub-thread.")
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop
|
||||
|
||||
|
||||
def extract_first_complete_json(s: str):
    """Extract the first brace-balanced JSON object from `s`; None on failure."""
    open_positions = []
    start_idx = None

    for pos, ch in enumerate(s):
        if ch == '{':
            open_positions.append(pos)
            if start_idx is None:
                start_idx = pos
        elif ch == '}' and open_positions:
            open_positions.pop()
            if not open_positions:
                # Outermost brace closed: attempt to parse this candidate.
                candidate = s[start_idx:pos + 1]
                try:
                    # Raw newlines are invalid inside JSON strings anyway.
                    return json.loads(candidate.replace("\n", ""))
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decoding failed: {e}. Attempted string: {candidate[:50]}...")
                    return None
                finally:
                    start_idx = None
    logger.warning("No complete JSON object found in the input string.")
    return None
|
||||
|
||||
def parse_value(value: str):
    """Coerce a trimmed string to None/bool/int/float, else return it unquoted.

    Works as a more broad, safe substitute for `eval()` on scalar tokens.
    """
    value = value.strip()

    keywords = {"null": None, "true": True, "false": False}
    if value in keywords:
        return keywords[value]
    # Numeric attempt: a dot means "try float", otherwise "try int".
    try:
        return float(value) if '.' in value else int(value)
    except ValueError:
        # Not numeric: treat as a string, dropping surrounding double quotes.
        return value.strip('"')
|
||||
|
||||
def extract_values_from_json(json_string, keys=("reasoning", "answer", "data"), allow_no_quotes=False):
    """Best-effort extraction of key/value pairs from a malformed JSON string.

    Handles unquoted keys, unquoted scalar values, and one level of nested
    objects (parsed recursively). Scalars are coerced via ``parse_value``.

    NOTE(review): ``keys`` and ``allow_no_quotes`` are currently unused by the
    implementation (every key found is extracted); kept for interface
    compatibility. The former mutable list default was replaced with a tuple
    to avoid the shared-mutable-default pitfall.

    Returns:
        dict: extracted key/value pairs; empty (with a warning) when nothing matched.
    """
    extracted_values = {}

    # Matches `key: value` where value is a {...} object, a quoted string,
    # or a bare token running up to the next ',' or '}'.
    regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'

    for match in re.finditer(regex_pattern, json_string, re.DOTALL):
        key = match.group('key').strip('"')  # Strip quotes from key
        value = match.group('value').strip()

        # If the value is a nested JSON object, recursively parse it.
        if value.startswith('{') and value.endswith('}'):
            extracted_values[key] = extract_values_from_json(value)
        else:
            # Parse the value into the appropriate type (int, float, bool, etc.)
            extracted_values[key] = parse_value(value)

    if not extracted_values:
        logger.warning("No values could be extracted from the string.")

    return extracted_values
|
||||
|
||||
|
||||
def convert_response_to_json(response: str) -> dict:
    """Parse an LLM response into a dict, falling back to lenient extraction."""
    parsed = extract_first_complete_json(response)

    if parsed is None:
        # Strict parsing failed; fall back to the regex-based extractor.
        logger.info("Attempting to extract values from a non-standard JSON string...")
        parsed = extract_values_from_json(response, allow_no_quotes=True)

    if not parsed:
        logger.error("Unable to extract meaningful data from the response.")
    else:
        logger.info("JSON data successfully extracted.")

    return parsed
|
||||
|
||||
|
||||
|
||||
|
||||
class TokenizerWrapper:
    """Uniform encode/decode facade over tiktoken or HuggingFace tokenizers."""

    def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
        self.tokenizer_type = tokenizer_type
        self.model_name = model_name
        self._tokenizer = None
        self._lazy_load_tokenizer()

    def _lazy_load_tokenizer(self):
        """Load the underlying tokenizer on first use; no-op afterwards."""
        if self._tokenizer is not None:
            return
        logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
        if self.tokenizer_type == "tiktoken":
            self._tokenizer = tiktoken.encoding_for_model(self.model_name)
            return
        if self.tokenizer_type == "huggingface":
            if AutoTokenizer is None:
                raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
            return
        raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")

    def get_tokenizer(self):
        """Expose the raw tokenizer object for special cases (e.g. decode_batch)."""
        self._lazy_load_tokenizer()
        return self._tokenizer

    def encode(self, text: str) -> list[int]:
        """Encode `text` into a list of token ids."""
        self._lazy_load_tokenizer()
        return self._tokenizer.encode(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode a list of token ids back into text."""
        self._lazy_load_tokenizer()
        return self._tokenizer.decode(tokens)

    def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
        """Decode many token lists at once, for efficiency and API symmetry."""
        self._lazy_load_tokenizer()
        # HuggingFace tokenizers have a native batch_decode; tiktoken does
        # not, so emulate it with a per-item decode.
        if self.tokenizer_type == "tiktoken":
            return [self._tokenizer.decode(chunk) for chunk in tokens_list]
        if self.tokenizer_type == "huggingface":
            return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
        raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
|
||||
|
||||
|
||||
|
||||
def truncate_list_by_token_size(
    list_data: list,
    key: callable,
    max_token_size: int,
    tokenizer_wrapper: TokenizerWrapper
):
    """Return the longest prefix of `list_data` whose keyed texts fit the budget.

    Each item contributes len(encode(key(item))) + 1 tokens; the defensive +1
    models joining the pieces with "\\n". A non-positive budget yields [].
    """
    if max_token_size <= 0:
        return []
    running_total = 0
    for idx, item in enumerate(list_data):
        running_total += len(tokenizer_wrapper.encode(key(item))) + 1
        if running_total > max_token_size:
            return list_data[:idx]
    return list_data
|
||||
|
||||
|
||||
def compute_mdhash_id(content, prefix: str = ""):
    """Deterministic id for `content`: optional prefix + MD5 hex digest."""
    digest = md5(content.encode()).hexdigest()
    return f"{prefix}{digest}"
|
||||
|
||||
|
||||
def write_json(json_obj, file_name):
    """Pretty-print `json_obj` to `file_name` as UTF-8 JSON (non-ASCII kept)."""
    with open(file_name, "w", encoding="utf-8") as handle:
        json.dump(json_obj, handle, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_json(file_name):
    """Read a JSON file; returns None when the file does not exist."""
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
# it's dirty to type, so it's a good way to have fun
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
    """Wrap a prompt/answer pair as chat messages.

    Bedrock's Converse API expects each content as a list of text blocks;
    the OpenAI API takes plain strings.
    """
    user_content = [{"text": prompt}] if using_amazon_bedrock else prompt
    assistant_content = (
        [{"text": generated_content}] if using_amazon_bedrock else generated_content
    )
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
|
||||
|
||||
|
||||
def is_float_regex(value):
    """True when `value` looks like a (possibly signed) decimal number."""
    number_pattern = r"^[-+]?[0-9]*\.?[0-9]+$"
    return re.match(number_pattern, value) is not None
|
||||
|
||||
|
||||
def compute_args_hash(*args):
    """Stable MD5 hex digest of the stringified argument tuple (cache key)."""
    serialized = str(args)
    return md5(serialized.encode()).hexdigest()
|
||||
|
||||
|
||||
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split `content` on any of `markers`; trims pieces and drops empty ones."""
    if not markers:
        return [content]
    split_pattern = "|".join(re.escape(marker) for marker in markers)
    pieces = re.split(split_pattern, content)
    return [piece.strip() for piece in pieces if piece.strip()]
|
||||
|
||||
|
||||
def enclose_string_with_quotes(content: Any) -> str:
    """Render `content` for CSV: numbers verbatim, everything else double-quoted."""
    if isinstance(content, numbers.Number):
        return str(content)
    # Normalize: strip whitespace, then any pre-existing single/double quotes.
    text = str(content).strip().strip("'").strip('"')
    return f'"{text}"'
|
||||
|
||||
|
||||
def list_of_list_to_csv(data: list[list]):
    """Serialize rows to a ',\\t'-separated CSV-like string, one row per line."""
    rendered_rows = []
    for row in data:
        cells = [f"{enclose_string_with_quotes(cell)}" for cell in row]
        rendered_rows.append(",\t".join(cells))
    return "\n".join(rendered_rows)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------------
|
||||
# Refer the utils functions of the official GraphRAG implementation:
|
||||
# https://github.com/microsoft/graphrag
|
||||
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # Non-string input passes through untouched.
    if not isinstance(input, str):
        return input

    unescaped = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
|
||||
|
||||
|
||||
# Utils types -----------------------------------------------------------------------
|
||||
@dataclass
class EmbeddingFunc:
    """Callable wrapper bundling an async embedding function with its metadata.

    Instances behave like the wrapped coroutine function while exposing
    `embedding_dim` and `max_token_size` for storage backends to read.
    """

    embedding_dim: int  # dimensionality of the vectors `func` returns
    max_token_size: int  # maximum input size the underlying model accepts
    func: callable  # the wrapped async embedding callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        # Pure passthrough: delegate to the wrapped coroutine function.
        return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
# Decorators ------------------------------------------------------------------------
|
||||
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Add restriction of maximum async calling times for a async func.

    Decorator factory: at most `max_size` invocations of the wrapped coroutine
    run concurrently; extra callers poll every `waitting_time` seconds.

    Args:
        max_size: maximum number of in-flight calls.
        waitting_time: polling interval (seconds) while waiting for a slot.
    """

    def final_decro(func):
        """Not using async.Semaphore to avoid use nest-asyncio"""
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            # Busy-wait until a slot frees up.
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # Bug fix: release the slot even when `func` raises. The
                # original only decremented on success, so each exception
                # leaked a slot and eventually deadlocked all callers.
                __current_size -= 1

        return wait_func

    return final_decro
|
||||
|
||||
|
||||
def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator factory: attach EmbeddingFunc metadata to an async function."""

    def final_decro(func) -> EmbeddingFunc:
        # Wrap the function together with the supplied metadata attributes.
        return EmbeddingFunc(**kwargs, func=func)

    return final_decro
|
||||
186
rag-web-ui/backend/nano_graphrag/base.py
Normal file
186
rag-web-ui/backend/nano_graphrag/base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Union, Literal, Generic, TypeVar, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._utils import EmbeddingFunc
|
||||
|
||||
|
||||
@dataclass
class QueryParam:
    """Parameters controlling a GraphRAG query (local, global, or naive mode)."""

    # Which query pipeline to run.
    mode: Literal["local", "global", "naive"] = "global"
    # If True, return only the assembled context instead of an LLM answer.
    only_need_context: bool = False
    # Free-form hint for the desired answer format, passed to the LLM prompt.
    response_type: str = "Multiple Paragraphs"
    # Community hierarchy level to consider.
    level: int = 2
    top_k: int = 20
    # naive search
    # FIX: annotated so this is a real dataclass field; previously it lacked an
    # annotation and was silently a shared class attribute excluded from __init__.
    naive_max_token_for_text_unit: int = 12000
    # local search
    local_max_token_for_text_unit: int = 4000  # 12000 * 0.33
    local_max_token_for_local_context: int = 4800  # 12000 * 0.4
    local_max_token_for_community_report: int = 3200  # 12000 * 0.27
    local_community_single_one: bool = False
    # global search
    global_min_community_rating: float = 0
    global_max_consider_community: float = 512
    global_max_token_for_community_report: int = 16384
    global_special_community_map_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )
|
||||
|
||||
|
||||
# Chunk of source text plus bookkeeping produced by the chunking step.
TextChunkSchema = TypedDict(
    "TextChunkSchema",
    {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)

# One community in the graph clustering hierarchy.
SingleCommunitySchema = TypedDict(
    "SingleCommunitySchema",
    {
        "level": int,
        "title": str,
        # FIX: each edge is a (source, target) node-name pair; the previous
        # annotation list[list[str, str]] is not a valid generic form.
        "edges": list[tuple[str, str]],
        "nodes": list[str],
        "chunk_ids": list[str],
        "occurrence": float,
        "sub_communities": list[str],
    },
)


class CommunitySchema(SingleCommunitySchema):
    """A community enriched with its generated report (text and parsed JSON)."""

    report_string: str
    report_json: dict


T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
class StorageNameSpace:
    """Base class for all storage backends, scoped by a namespace.

    Subclasses receive the full GraphRAG config as a plain dict so they can
    read whatever settings they need (working dir, batch sizes, ...).
    """

    # Logical name of this storage (e.g. "full_docs", "text_chunks").
    namespace: str
    # The GraphRAG instance's configuration, as produced by dataclasses.asdict.
    global_config: dict

    async def index_start_callback(self):
        """Hook invoked before indexing begins; default is a no-op."""
        pass

    async def index_done_callback(self):
        """commit the storage operations after indexing"""
        pass

    async def query_done_callback(self):
        """commit the storage operations after querying"""
        pass
|
||||
|
||||
|
||||
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Abstract vector store: embeds content on upsert and serves top-k queries."""

    # Async embedding function used to vectorize content before insertion.
    embedding_func: EmbeddingFunc
    # Extra fields from each value dict to persist alongside its vector.
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Return the ``top_k`` entries most similar to ``query``."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use 'content' field from value for embedding, use key as id.

        If embedding_func is None, use 'embedding' field from value
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Abstract key/value store holding values of type ``T``."""

    async def all_keys(self) -> list[str]:
        """Return every key currently stored."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Return the value for ``id``, or None if absent."""
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Batch get; optionally project each value down to ``fields``."""
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """return un-exist keys"""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Insert or update the given key/value pairs."""
        raise NotImplementedError

    async def drop(self):
        """Delete all data in this namespace."""
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Abstract graph store for entities (nodes) and relationships (edges).

    Batch variants mirror the single-item methods and exist so backends can
    amortize round trips.
    """

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    # NOTE(review): return annotation corrected from List[str]; degrees are ints,
    # matching node_degree and edge_degrees_batch.
    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        raise NotImplementedError

    async def clustering(self, algorithm: str):
        # Run the named community-detection algorithm over the whole graph.
        raise NotImplementedError

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Return the community representation with report and nodes"""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in nano-graphrag.")
|
||||
171
rag-web-ui/backend/nano_graphrag/entity_extraction/extract.py
Normal file
171
rag-web-ui/backend/nano_graphrag/entity_extraction/extract.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from typing import Union
|
||||
import pickle
|
||||
import asyncio
|
||||
from openai import BadRequestError
|
||||
from collections import defaultdict
|
||||
import dspy
|
||||
from nano_graphrag.base import (
|
||||
BaseGraphStorage,
|
||||
BaseVectorStorage,
|
||||
TextChunkSchema,
|
||||
)
|
||||
from nano_graphrag.prompt import PROMPTS
|
||||
from nano_graphrag._utils import logger, compute_mdhash_id
|
||||
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
|
||||
from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert
|
||||
|
||||
|
||||
async def generate_dataset(
    chunks: dict[str, TextChunkSchema],
    filepath: str,
    save_dataset: bool = True,
    global_config: Union[dict, None] = None,
) -> list[dspy.Example]:
    """Run entity/relationship extraction over ``chunks`` and build a dspy dataset.

    Args:
        chunks: Mapping of chunk id -> text chunk.
        filepath: Where to pickle the dataset when ``save_dataset`` is True.
        save_dataset: Persist the filtered examples to ``filepath``.
        global_config: GraphRAG config; may point at a compiled dspy module.

    Returns:
        Examples that produced at least one entity and one relationship.
    """
    # FIX: avoid a mutable default argument ({} was shared across calls).
    if global_config is None:
        global_config = {}
    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(
        chunk_key_dp: tuple[str, TextChunkSchema]
    ) -> dspy.Example:
        # Extract entities/relationships from one chunk and wrap as a dspy Example.
        nonlocal already_processed, already_entities, already_relations
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The extractor is synchronous; run it off the event loop thread.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []
        example = dspy.Example(
            input_text=content, entities=entities, relationships=relationships
        ).with_inputs("input_text")
        already_entities += len(entities)
        already_relations += len(relationships)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return example

    examples = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    # Keep only examples with at least one entity AND one relationship.
    filtered_examples = [
        example
        for example in examples
        if len(example.entities) > 0 and len(example.relationships) > 0
    ]
    num_filtered_examples = len(examples) - len(filtered_examples)
    if save_dataset:
        with open(filepath, "wb") as f:
            pickle.dump(filtered_examples, f)
        if filtered_examples:
            logger.info(
                f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples"
            )
        else:
            # FIX: previously indexed filtered_examples[0] unconditionally, which
            # raised IndexError when every example was filtered out.
            logger.info(
                f"Saved 0 examples, filtered {num_filtered_examples} examples"
            )

    return filtered_examples
|
||||
|
||||
|
||||
async def extract_entities_dspy(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    """Extract entities/relationships from chunks with dspy and merge into the graph.

    Each chunk is processed concurrently; per-chunk results are grouped by entity
    name / edge pair, merged, and upserted into ``knwoledge_graph_inst`` (and the
    optional ``entity_vdb``).  Returns the graph storage, or None when nothing was
    extracted.
    """
    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    # Shared progress counters, mutated by the per-chunk workers below.
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        # Extract from one chunk; returns (nodes keyed by name, edges keyed by
        # (src, tgt)).  Values are lists because duplicates are merged later.
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The extractor is synchronous; run it off the event loop thread.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            # Best-effort: a rejected request yields an empty extraction.
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)

        for entity in entities:
            # Track which chunk each entity came from.
            entity["source_id"] = chunk_key
            maybe_nodes[entity["entity_name"]].append(entity)
            already_entities += 1

        for relationship in relationships:
            relationship["source_id"] = chunk_key
            maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(
                relationship
            )
            already_relations += 1

        already_processed += 1
        # Spinner-style progress line (overwritten in place via \r).
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()
    # Re-group all per-chunk results so duplicates across chunks merge together.
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[k].extend(v)
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        # Index entities in the vector DB: id is a hash of the name, the embedded
        # text is name + description.
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    return knwoledge_graph_inst
|
||||
62
rag-web-ui/backend/nano_graphrag/entity_extraction/metric.py
Normal file
62
rag-web-ui/backend/nano_graphrag/entity_extraction/metric.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import dspy
|
||||
from nano_graphrag.entity_extraction.module import Relationship
|
||||
|
||||
|
||||
class AssessRelationships(dspy.Signature):
    # NOTE(review): for dspy.Signature classes the docstring below is the LLM
    # instruction prompt — do not edit its wording casually.
    """
    Assess the similarity between gold and predicted relationships:
    1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
    2. For matched pairs, compare:
       a) Description similarity (semantic meaning)
       b) Weight similarity
       c) Order similarity
    3. Consider unmatched relationships as penalties.
    4. Aggregate scores, accounting for precision and recall.
    5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).

    Key considerations:
    - Prioritize matching based on entity pairs over exact string matches.
    - Use semantic similarity for descriptions rather than exact matches.
    - Weight the importance of different aspects (e.g., entity matching, description, weight, order).
    - Balance the impact of matched and unmatched relationships in the final score.
    """

    gold_relationships: list[Relationship] = dspy.InputField(
        desc="The gold-standard relationships to compare against."
    )
    predicted_relationships: list[Relationship] = dspy.InputField(
        desc="The predicted relationships to compare against the gold-standard relationships."
    )
    similarity_score: float = dspy.OutputField(
        desc="Similarity score between 0 and 1, with 1 being the highest similarity."
    )
|
||||
|
||||
|
||||
def relationships_similarity_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """LLM-judged similarity (0..1) between gold and predicted relationships.

    Builds Relationship models from both sides and asks an AssessRelationships
    chain-of-thought judge for a single score.
    """
    model = dspy.ChainOfThought(AssessRelationships)
    gold_relationships = [Relationship(**item) for item in gold["relationships"]]
    predicted_relationships = [Relationship(**item) for item in pred["relationships"]]
    # The judge may return a string-typed score; coerce to float.
    similarity_score = float(
        model(
            gold_relationships=gold_relationships,
            predicted_relationships=predicted_relationships,
        ).similarity_score
    )
    return similarity_score
|
||||
|
||||
|
||||
def entity_recall_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Recall of predicted entity names against the gold entity names."""
    gold_names = {item["entity_name"] for item in gold["entities"]}
    pred_names = {item["entity_name"] for item in pred["entities"]}
    hits = len(gold_names & pred_names)
    missed = len(gold_names - pred_names)
    total = hits + missed
    # An empty gold set yields 0 rather than dividing by zero.
    return hits / total if total > 0 else 0
|
||||
330
rag-web-ui/backend/nano_graphrag/entity_extraction/module.py
Normal file
330
rag-web-ui/backend/nano_graphrag/entity_extraction/module.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import dspy
|
||||
from pydantic import BaseModel, Field
|
||||
from nano_graphrag._utils import clean_str
|
||||
from nano_graphrag._utils import logger
|
||||
|
||||
|
||||
"""
|
||||
Obtained from:
|
||||
https://github.com/SciPhi-AI/R2R/blob/6e958d1e451c1cb10b6fc868572659785d1091cb/r2r/providers/prompts/defaults.jsonl
|
||||
"""
|
||||
ENTITY_TYPES = [
|
||||
"PERSON",
|
||||
"ORGANIZATION",
|
||||
"LOCATION",
|
||||
"DATE",
|
||||
"TIME",
|
||||
"MONEY",
|
||||
"PERCENTAGE",
|
||||
"PRODUCT",
|
||||
"EVENT",
|
||||
"LANGUAGE",
|
||||
"NATIONALITY",
|
||||
"RELIGION",
|
||||
"TITLE",
|
||||
"PROFESSION",
|
||||
"ANIMAL",
|
||||
"PLANT",
|
||||
"DISEASE",
|
||||
"MEDICATION",
|
||||
"CHEMICAL",
|
||||
"MATERIAL",
|
||||
"COLOR",
|
||||
"SHAPE",
|
||||
"MEASUREMENT",
|
||||
"WEATHER",
|
||||
"NATURAL_DISASTER",
|
||||
"AWARD",
|
||||
"LAW",
|
||||
"CRIME",
|
||||
"TECHNOLOGY",
|
||||
"SOFTWARE",
|
||||
"HARDWARE",
|
||||
"VEHICLE",
|
||||
"FOOD",
|
||||
"DRINK",
|
||||
"SPORT",
|
||||
"MUSIC_GENRE",
|
||||
"INSTRUMENT",
|
||||
"ARTWORK",
|
||||
"BOOK",
|
||||
"MOVIE",
|
||||
"TV_SHOW",
|
||||
"ACADEMIC_SUBJECT",
|
||||
"SCIENTIFIC_THEORY",
|
||||
"POLITICAL_PARTY",
|
||||
"CURRENCY",
|
||||
"STOCK_SYMBOL",
|
||||
"FILE_TYPE",
|
||||
"PROGRAMMING_LANGUAGE",
|
||||
"MEDICAL_PROCEDURE",
|
||||
"CELESTIAL_BODY",
|
||||
]
|
||||
|
||||
|
||||
class Entity(BaseModel):
    # NOTE(review): these Field descriptions feed the LLM via the dspy output
    # schema; keep their wording stable.
    entity_name: str = Field(..., description="The name of the entity.")
    entity_type: str = Field(..., description="The type of the entity.")
    description: str = Field(
        ..., description="The description of the entity, in details and comprehensive."
    )
    importance_score: float = Field(
        ...,
        ge=0,
        le=1,
        description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.",
    )

    def to_dict(self):
        # Normalize for graph storage: upper-cased cleaned name/type, cleaned
        # description, plain float score.
        return {
            "entity_name": clean_str(self.entity_name.upper()),
            "entity_type": clean_str(self.entity_type.upper()),
            "description": clean_str(self.description),
            "importance_score": float(self.importance_score),
        }
|
||||
|
||||
|
||||
class Relationship(BaseModel):
    # NOTE(review): these Field descriptions feed the LLM via the dspy output
    # schema; keep their wording stable.
    src_id: str = Field(..., description="The name of the source entity.")
    tgt_id: str = Field(..., description="The name of the target entity.")
    description: str = Field(
        ...,
        description="The description of the relationship between the source and target entity, in details and comprehensive.",
    )
    weight: float = Field(
        ...,
        ge=0,
        le=1,
        description="The weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.",
    )
    order: int = Field(
        ...,
        ge=1,
        le=3,
        description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.",
    )

    def to_dict(self):
        # Normalize for graph storage: upper-cased cleaned endpoint ids so they
        # match Entity.to_dict()'s entity_name normalization.
        return {
            "src_id": clean_str(self.src_id.upper()),
            "tgt_id": clean_str(self.tgt_id.upper()),
            "description": clean_str(self.description),
            "weight": float(self.weight),
            "order": int(self.order),
        }
|
||||
|
||||
|
||||
class CombinedExtraction(dspy.Signature):
    # NOTE(review): for dspy.Signature classes the docstring below is the LLM
    # instruction prompt — do not edit its wording casually.
    """
    Given a text document that is potentially relevant to this activity and a list of entity types,
    identify all entities of those types from the text and all relationships among the identified entities.

    Entity Guidelines:
    1. Each entity name should be an actual atomic word from the input text.
    2. Avoid duplicates and generic terms.
    3. Make sure descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
        a). The entity's role or significance in the context
        b). Key attributes or characteristics
        c). Relationships to other entities (if applicable)
        d). Historical or cultural relevance (if applicable)
        e). Any notable actions or events associated with the entity
    4. All entity types from the text must be included.
    5. IMPORTANT: Only use entity types from the provided 'entity_types' list. Do not introduce new entity types.

    Relationship Guidelines:
    1. Make sure relationship descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
        a). The nature of the relationship (e.g., familial, professional, causal)
        b). The impact or significance of the relationship on both entities
        c). Any historical or contextual information relevant to the relationship
        d). How the relationship evolved over time (if applicable)
        e). Any notable events or actions that resulted from this relationship
    2. Include direct relationships (order 1) as well as higher-order relationships (order 2 and 3):
        a). Direct relationships: Immediate connections between entities.
        b). Second-order relationships: Indirect effects or connections that result from direct relationships.
        c). Third-order relationships: Further indirect effects that result from second-order relationships.
    3. The "src_id" and "tgt_id" fields must exactly match entity names from the extracted entities list.
    """

    input_text: str = dspy.InputField(
        desc="The text to extract entities and relationships from."
    )
    entity_types: list[str] = dspy.InputField(
        desc="List of entity types used for extraction."
    )
    entities: list[Entity] = dspy.OutputField(
        desc="List of entities extracted from the text and the entity types."
    )
    relationships: list[Relationship] = dspy.OutputField(
        desc="List of relationships extracted from the text and the entity types."
    )
|
||||
|
||||
|
||||
class CritiqueCombinedExtraction(dspy.Signature):
    # NOTE(review): docstring below is the LLM instruction prompt (dspy.Signature).
    """
    Critique the current extraction of entities and relationships from a given text.
    Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.

    Critique Guidelines:
    1. Evaluate if all relevant entities from the input text are captured and correctly typed.
    2. Check if entity descriptions are comprehensive and follow the provided guidelines.
    3. Assess the completeness of relationship extractions, including higher-order relationships.
    4. Verify that relationship descriptions are detailed and follow the provided guidelines.
    5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
    6. Suggest specific improvements or additions to enhance the quality of the extraction.
    """

    input_text: str = dspy.InputField(
        desc="The original text from which entities and relationships were extracted."
    )
    entity_types: list[str] = dspy.InputField(
        desc="List of valid entity types for this extraction task."
    )
    current_entities: list[Entity] = dspy.InputField(
        desc="List of currently extracted entities to be critiqued."
    )
    current_relationships: list[Relationship] = dspy.InputField(
        desc="List of currently extracted relationships to be critiqued."
    )
    entity_critique: str = dspy.OutputField(
        desc="Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy.."
    )
    relationship_critique: str = dspy.OutputField(
        desc="Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy.."
    )
|
||||
|
||||
|
||||
class RefineCombinedExtraction(dspy.Signature):
    # NOTE(review): docstring below is the LLM instruction prompt (dspy.Signature).
    """
    Refine the current extraction of entities and relationships based on the provided critique.
    Improve completeness, accuracy, and adherence to the extraction guidelines.

    Refinement Guidelines:
    1. Address all points raised in the entity and relationship critiques.
    2. Add missing entities and relationships identified in the critique.
    3. Improve entity and relationship descriptions as suggested.
    4. Ensure all refinements still adhere to the original extraction guidelines.
    5. Maintain consistency between entities and relationships during refinement.
    6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
    """

    input_text: str = dspy.InputField(
        desc="The original text from which entities and relationships were extracted."
    )
    entity_types: list[str] = dspy.InputField(
        desc="List of valid entity types for this extraction task."
    )
    current_entities: list[Entity] = dspy.InputField(
        desc="List of currently extracted entities to be refined."
    )
    current_relationships: list[Relationship] = dspy.InputField(
        desc="List of currently extracted relationships to be refined."
    )
    entity_critique: str = dspy.InputField(
        desc="Detailed critique of the current entities to guide refinement."
    )
    relationship_critique: str = dspy.InputField(
        desc="Detailed critique of the current relationships to guide refinement."
    )
    refined_entities: list[Entity] = dspy.OutputField(
        desc="List of refined entities, addressing the entity critique and improving upon the current entities."
    )
    refined_relationships: list[Relationship] = dspy.OutputField(
        desc="List of refined relationships, addressing the relationship critique and improving upon the current relationships."
    )
|
||||
|
||||
|
||||
class TypedEntityRelationshipExtractorException(dspy.Module):
    """Wrapper module that converts selected predictor exceptions into an empty
    prediction instead of propagating them.

    Args:
        predictor: The dspy module to guard.
        exception_types: Exception classes to swallow; anything else re-raises.
    """

    def __init__(
        self,
        predictor: dspy.Module,
        exception_types: tuple[type[Exception]] = (Exception,),
    ):
        super().__init__()
        self.predictor = predictor
        self.exception_types = exception_types

    def copy(self):
        # FIX: previously the copy dropped the configured exception_types, so it
        # silently reverted to the default of catching every Exception.
        return TypedEntityRelationshipExtractorException(
            self.predictor, self.exception_types
        )

    def forward(self, **kwargs):
        try:
            return self.predictor(**kwargs)
        except Exception as e:
            if isinstance(e, self.exception_types):
                # Swallow configured exception types: return an empty extraction.
                return dspy.Prediction(entities=[], relationships=[])
            # Re-raise unexpected exceptions with the original traceback.
            raise
|
||||
|
||||
|
||||
class TypedEntityRelationshipExtractor(dspy.Module):
    """Extract typed entities and relationships from text, optionally running a
    critique-then-refine loop to improve the initial extraction.

    Args:
        lm: Language model to use; falls back to ``dspy.settings.lm`` when None.
        max_retries: Retry budget for each chain-of-thought predictor.
        entity_types: Allowed entity-type vocabulary (defaults to ENTITY_TYPES).
        self_refine: Whether to run critique/refine passes after extraction.
        num_refine_turns: Number of critique/refine iterations when enabled.
    """

    def __init__(
        self,
        lm: dspy.LM = None,
        max_retries: int = 3,
        entity_types: list[str] = ENTITY_TYPES,
        self_refine: bool = False,
        num_refine_turns: int = 1,
    ):
        super().__init__()
        self.lm = lm
        self.entity_types = entity_types
        self.self_refine = self_refine
        self.num_refine_turns = num_refine_turns

        self.extractor = dspy.ChainOfThought(
            signature=CombinedExtraction, max_retries=max_retries
        )
        # Guard the extractor: ValueError (e.g. output parsing failures) yields
        # an empty prediction instead of aborting the whole pipeline.
        self.extractor = TypedEntityRelationshipExtractorException(
            self.extractor, exception_types=(ValueError,)
        )

        if self.self_refine:
            self.critique = dspy.ChainOfThought(
                signature=CritiqueCombinedExtraction, max_retries=max_retries
            )
            self.refine = dspy.ChainOfThought(
                signature=RefineCombinedExtraction, max_retries=max_retries
            )

    def forward(self, input_text: str) -> dspy.Prediction:
        """Run extraction (and optional refinement); return dict-form results."""
        with dspy.context(lm=self.lm if self.lm is not None else dspy.settings.lm):
            extraction_result = self.extractor(
                input_text=input_text, entity_types=self.entity_types
            )

            current_entities: list[Entity] = extraction_result.entities
            current_relationships: list[Relationship] = extraction_result.relationships

            if self.self_refine:
                # Each turn: critique the current extraction, then refine it.
                for _ in range(self.num_refine_turns):
                    critique_result = self.critique(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                    )
                    refined_result = self.refine(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                        entity_critique=critique_result.entity_critique,
                        relationship_critique=critique_result.relationship_critique,
                    )
                    logger.debug(
                        f"entities: {len(current_entities)} | refined_entities: {len(refined_result.refined_entities)}"
                    )
                    logger.debug(
                        f"relationships: {len(current_relationships)} | refined_relationships: {len(refined_result.refined_relationships)}"
                    )
                    current_entities = refined_result.refined_entities
                    current_relationships = refined_result.refined_relationships

            # Convert pydantic models to normalized plain dicts for downstream use.
            entities = [entity.to_dict() for entity in current_entities]
            relationships = [
                relationship.to_dict() for relationship in current_relationships
            ]

            return dspy.Prediction(entities=entities, relationships=relationships)
|
||||
382
rag-web-ui/backend/nano_graphrag/graphrag.py
Normal file
382
rag-web-ui/backend/nano_graphrag/graphrag.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
|
||||
|
||||
from ._llm import (
|
||||
amazon_bedrock_embedding,
|
||||
create_amazon_bedrock_complete_function,
|
||||
gpt_4o_complete,
|
||||
gpt_4o_mini_complete,
|
||||
openai_embedding,
|
||||
azure_gpt_4o_complete,
|
||||
azure_openai_embedding,
|
||||
azure_gpt_4o_mini_complete,
|
||||
)
|
||||
from ._op import (
|
||||
chunking_by_token_size,
|
||||
extract_entities,
|
||||
generate_community_report,
|
||||
get_chunks,
|
||||
local_query,
|
||||
global_query,
|
||||
naive_query,
|
||||
)
|
||||
from ._storage import (
|
||||
JsonKVStorage,
|
||||
NanoVectorDBStorage,
|
||||
NetworkXStorage,
|
||||
)
|
||||
from ._utils import (
|
||||
EmbeddingFunc,
|
||||
compute_mdhash_id,
|
||||
limit_async_func_call,
|
||||
convert_response_to_json,
|
||||
always_get_an_event_loop,
|
||||
logger,
|
||||
TokenizerWrapper,
|
||||
)
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
BaseKVStorage,
|
||||
BaseVectorStorage,
|
||||
StorageNameSpace,
|
||||
QueryParam,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphRAG:
|
||||
working_dir: str = field(
|
||||
default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||
)
|
||||
# graph mode
|
||||
enable_local: bool = True
|
||||
enable_naive_rag: bool = False
|
||||
|
||||
# text chunking
|
||||
tokenizer_type: str = "tiktoken" # or 'huggingface'
|
||||
tiktoken_model_name: str = "gpt-4o"
|
||||
huggingface_model_name: str = "bert-base-uncased" # default HF model
|
||||
chunk_func: Callable[
|
||||
[
|
||||
list[list[int]],
|
||||
List[str],
|
||||
TokenizerWrapper,
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
],
|
||||
List[Dict[str, Union[str, int]]],
|
||||
] = chunking_by_token_size
|
||||
chunk_token_size: int = 1200
|
||||
chunk_overlap_token_size: int = 100
|
||||
|
||||
|
||||
# entity extraction
|
||||
entity_extract_max_gleaning: int = 1
|
||||
entity_summary_to_max_tokens: int = 500
|
||||
|
||||
# graph clustering
|
||||
graph_cluster_algorithm: str = "leiden"
|
||||
max_graph_cluster_size: int = 10
|
||||
graph_cluster_seed: int = 0xDEADBEEF
|
||||
|
||||
# node embedding
|
||||
node_embedding_algorithm: str = "node2vec"
|
||||
node2vec_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"dimensions": 1536,
|
||||
"num_walks": 10,
|
||||
"walk_length": 40,
|
||||
"num_walks": 10,
|
||||
"window_size": 2,
|
||||
"iterations": 3,
|
||||
"random_seed": 3,
|
||||
}
|
||||
)
|
||||
|
||||
# community reports
|
||||
special_community_report_llm_kwargs: dict = field(
|
||||
default_factory=lambda: {"response_format": {"type": "json_object"}}
|
||||
)
|
||||
|
||||
# text embedding
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
||||
embedding_batch_num: int = 32
|
||||
embedding_func_max_async: int = 16
|
||||
query_better_than_threshold: float = 0.2
|
||||
|
||||
# LLM
|
||||
using_azure_openai: bool = False
|
||||
using_amazon_bedrock: bool = False
|
||||
best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
|
||||
best_model_func: callable = gpt_4o_complete
|
||||
best_model_max_token_size: int = 32768
|
||||
best_model_max_async: int = 16
|
||||
cheap_model_func: callable = gpt_4o_mini_complete
|
||||
cheap_model_max_token_size: int = 32768
|
||||
cheap_model_max_async: int = 16
|
||||
|
||||
# entity extraction
|
||||
entity_extraction_func: callable = extract_entities
|
||||
|
||||
# storage
|
||||
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
||||
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
|
||||
enable_llm_cache: bool = True
|
||||
|
||||
# extension
|
||||
always_create_working_dir: bool = True
|
||||
addon_params: dict = field(default_factory=dict)
|
||||
convert_response_to_json_func: callable = convert_response_to_json
|
||||
|
||||
def __post_init__(self):
    """Finish construction after dataclass field init.

    Selects the effective LLM/embedding backends (Azure / Bedrock overrides),
    creates the working directory, instantiates all KV / vector / graph
    storages, and wraps model and embedding functions with concurrency limits.
    """
    # Dump the effective configuration at debug level for reproducibility.
    _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
    logger.debug(f"GraphRAG init with param:\n\n {_print_config}\n")

    # Tokenizer used for chunking and token counting throughout the pipeline.
    self.tokenizer_wrapper = TokenizerWrapper(
        tokenizer_type=self.tokenizer_type,
        model_name=self.tiktoken_model_name if self.tokenizer_type == "tiktoken" else self.huggingface_model_name
    )

    if self.using_azure_openai:
        # If there's no OpenAI API key, use Azure OpenAI
        # Only swap defaults; user-supplied callables are left untouched.
        if self.best_model_func == gpt_4o_complete:
            self.best_model_func = azure_gpt_4o_complete
        if self.cheap_model_func == gpt_4o_mini_complete:
            self.cheap_model_func = azure_gpt_4o_mini_complete
        if self.embedding_func == openai_embedding:
            self.embedding_func = azure_openai_embedding
        logger.info(
            "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
        )

    if self.using_amazon_bedrock:
        # Bedrock unconditionally overrides model funcs with per-model-id closures.
        self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
        self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
        self.embedding_func = amazon_bedrock_embedding
        logger.info(
            "Switched the default openai funcs to Amazon Bedrock"
        )

    if not os.path.exists(self.working_dir) and self.always_create_working_dir:
        logger.info(f"Creating working directory {self.working_dir}")
        os.makedirs(self.working_dir)

    # Key-value stores for raw docs, chunks, LLM cache and community reports.
    self.full_docs = self.key_string_value_json_storage_cls(
        namespace="full_docs", global_config=asdict(self)
    )

    self.text_chunks = self.key_string_value_json_storage_cls(
        namespace="text_chunks", global_config=asdict(self)
    )

    # The LLM cache is optional: None disables response caching entirely.
    self.llm_response_cache = (
        self.key_string_value_json_storage_cls(
            namespace="llm_response_cache", global_config=asdict(self)
        )
        if self.enable_llm_cache
        else None
    )

    self.community_reports = self.key_string_value_json_storage_cls(
        namespace="community_reports", global_config=asdict(self)
    )
    self.chunk_entity_relation_graph = self.graph_storage_cls(
        namespace="chunk_entity_relation", global_config=asdict(self)
    )

    # Throttle concurrent embedding calls to embedding_func_max_async.
    self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
        self.embedding_func
    )
    # Vector stores are created only for enabled features (local / naive RAG).
    self.entities_vdb = (
        self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        if self.enable_local
        else None
    )
    self.chunks_vdb = (
        self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        if self.enable_naive_rag
        else None
    )

    # Wrap model funcs with concurrency limits and bind the response cache.
    self.best_model_func = limit_async_func_call(self.best_model_max_async)(
        partial(self.best_model_func, hashing_kv=self.llm_response_cache)
    )
    self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
        partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
    )
|
||||
|
||||
|
||||
|
||||
def insert(self, string_or_strings):
    """Synchronous wrapper around :meth:`ainsert`.

    Runs the async insert pipeline to completion on the process event loop.
    """
    return always_get_an_event_loop().run_until_complete(
        self.ainsert(string_or_strings)
    )
|
||||
|
||||
def query(self, query: str, param: QueryParam = QueryParam()):
    """Synchronous wrapper around :meth:`aquery`.

    Runs the async query pipeline to completion on the process event loop.
    """
    return always_get_an_event_loop().run_until_complete(
        self.aquery(query, param)
    )
|
||||
|
||||
async def aquery(self, query: str, param: QueryParam = QueryParam()):
    """Answer `query` asynchronously using the mode selected in `param`.

    Supported modes are "local", "global" and "naive"; a ValueError is
    raised when the requested mode is disabled by configuration or unknown.
    NOTE(review): the default `param` is a shared instance created once at
    class-definition time — callers should not mutate it.
    """
    mode = param.mode
    # Reject modes whose pipelines were disabled at construction time.
    if mode == "local" and not self.enable_local:
        raise ValueError("enable_local is False, cannot query in local mode")
    if mode == "naive" and not self.enable_naive_rag:
        raise ValueError("enable_naive_rag is False, cannot query in naive mode")

    if mode == "local":
        response = await local_query(
            query,
            self.chunk_entity_relation_graph,
            self.entities_vdb,
            self.community_reports,
            self.text_chunks,
            param,
            self.tokenizer_wrapper,
            asdict(self),
        )
    elif mode == "global":
        response = await global_query(
            query,
            self.chunk_entity_relation_graph,
            self.entities_vdb,
            self.community_reports,
            self.text_chunks,
            param,
            self.tokenizer_wrapper,
            asdict(self),
        )
    elif mode == "naive":
        response = await naive_query(
            query,
            self.chunks_vdb,
            self.text_chunks,
            param,
            self.tokenizer_wrapper,
            asdict(self),
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")

    # Flush query-time storages (LLM cache) before returning the answer.
    await self._query_done()
    return response
|
||||
|
||||
async def ainsert(self, string_or_strings):
    """Insert one document or a list of documents into the index.

    Pipeline: dedup docs by content hash -> chunk -> (optionally) upsert
    chunk vectors for naive RAG -> extract entities into the graph ->
    re-cluster communities and regenerate reports -> commit KV storages.
    Documents/chunks already present are skipped; storages are always
    flushed via _insert_done(), even on early return.
    """
    await self._insert_start()
    try:
        if isinstance(string_or_strings, str):
            string_or_strings = [string_or_strings]
        # ---------- new docs
        # Key each doc by a hash of its stripped content so re-inserts dedup.
        new_docs = {
            compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
            for c in string_or_strings
        }
        _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
        if not len(new_docs):
            logger.warning(f"All docs are already in the storage")
            return
        logger.info(f"[New Docs] inserting {len(new_docs)} docs")

        # ---------- chunking

        inserting_chunks = get_chunks(
            new_docs=new_docs,
            chunk_func=self.chunk_func,
            overlap_token_size=self.chunk_overlap_token_size,
            max_token_size=self.chunk_token_size,
            tokenizer_wrapper=self.tokenizer_wrapper,
        )

        # Drop chunks whose hash keys already exist in storage.
        _add_chunk_keys = await self.text_chunks.filter_keys(
            list(inserting_chunks.keys())
        )
        inserting_chunks = {
            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
        }
        if not len(inserting_chunks):
            logger.warning(f"All chunks are already in the storage")
            return
        logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
        if self.enable_naive_rag:
            logger.info("Insert chunks for naive RAG")
            await self.chunks_vdb.upsert(inserting_chunks)

        # TODO: don't support incremental update for communities now, so we have to drop all
        await self.community_reports.drop()

        # ---------- extract/summary entity and upsert to graph
        logger.info("[Entity Extraction]...")
        # NOTE(review): "knwoledge_graph_inst" matches the (misspelled)
        # keyword parameter expected by the extraction func — do not "fix"
        # it here without renaming the callee's parameter too.
        maybe_new_kg = await self.entity_extraction_func(
            inserting_chunks,
            knwoledge_graph_inst=self.chunk_entity_relation_graph,
            entity_vdb=self.entities_vdb,
            tokenizer_wrapper=self.tokenizer_wrapper,
            global_config=asdict(self),
            using_amazon_bedrock=self.using_amazon_bedrock,
        )
        if maybe_new_kg is None:
            logger.warning("No new entities found")
            return
        self.chunk_entity_relation_graph = maybe_new_kg
        # ---------- update clusterings of graph
        logger.info("[Community Report]...")
        await self.chunk_entity_relation_graph.clustering(
            self.graph_cluster_algorithm
        )
        await generate_community_report(
            self.community_reports, self.chunk_entity_relation_graph, self.tokenizer_wrapper, asdict(self)
        )

        # ---------- commit upsertings and indexing
        await self.full_docs.upsert(new_docs)
        await self.text_chunks.upsert(inserting_chunks)
    finally:
        # Always flush storages, even on early return or error.
        await self._insert_done()
|
||||
|
||||
async def _insert_start(self):
    """Fire index_start_callback on storages that need setup before insert."""
    storages = (self.chunk_entity_relation_graph,)
    await asyncio.gather(
        *(
            cast(StorageNameSpace, inst).index_start_callback()
            for inst in storages
            if inst is not None
        )
    )
|
||||
|
||||
async def _insert_done(self):
    """Run index_done_callback on every active storage after an insert."""
    storages = (
        self.full_docs,
        self.text_chunks,
        self.llm_response_cache,
        self.community_reports,
        self.entities_vdb,
        self.chunks_vdb,
        self.chunk_entity_relation_graph,
    )
    # Disabled storages are None and are skipped.
    await asyncio.gather(
        *(
            cast(StorageNameSpace, inst).index_done_callback()
            for inst in storages
            if inst is not None
        )
    )
|
||||
|
||||
async def _query_done(self):
    """Flush query-time storages (currently only the LLM response cache)."""
    await asyncio.gather(
        *(
            cast(StorageNameSpace, inst).index_done_callback()
            for inst in (self.llm_response_cache,)
            if inst is not None
        )
    )
|
||||
305
rag-web-ui/backend/nano_graphrag/prompt.py
Normal file
305
rag-web-ui/backend/nano_graphrag/prompt.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
GraphRAG core prompts (Chinese, aerospace-oriented).
|
||||
"""
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
PROMPTS = {}
|
||||
|
||||
# Prompt: extract entity-centric claims/assertions from text (Chinese,
# aerospace-oriented). Output is delimiter-separated tuples.
PROMPTS[
    "claim_extraction"
] = """-任务定位-
你是航天知识情报分析助手,负责从文本中抽取实体相关的主张/断言(claim)。

-目标-
给定输入文本、实体约束和主张说明,抽取满足条件的实体及其对应主张,结果必须可溯源。

-执行步骤-
1. 先识别满足实体约束的命名实体。实体约束可能是实体名称列表,也可能是实体类型列表。
2. 对步骤1中的每个实体,抽取其作为主语的主张。每条主张需输出:
- Subject: 主张主体实体名(大写;必须来自步骤1)
- Object: 客体实体名(大写;未知时使用 **NONE**)
- Claim Type: 主张类型(大写;应可复用)
- Claim Status: **TRUE**、**FALSE** 或 **SUSPECTED**
- Claim Description: 说明主张的依据、逻辑和关键证据
- Claim Date: 起止时间(ISO-8601);未知时为 **NONE**
- Claim Source Text: 与主张直接相关的原文引文(尽量完整)

格式要求:
(<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)

3. 使用 **{record_delimiter}** 连接所有记录。
4. 结束时输出 {completion_delimiter}。

-约束-
- 仅基于输入文本,不得编造。
- 航天语境优先(如航天器、推进系统、姿态控制、任务阶段、地面系统、试验事件、故障模式等)。

-输入-
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output: """
|
||||
|
||||
# Prompt: generate a structured community analysis report. The model must
# return a JSON object ({title, summary, rating, rating_explanation, findings}).
PROMPTS[
    "community_report"
] = """你是航天领域知识图谱分析助手,负责为一个社区(community)生成结构化研判报告。

# 目标
根据给定的实体、关系和可选主张,输出可用于技术评审与任务决策的社区报告。

# 报告结构
必须返回 JSON 字符串,结构如下:
{
"title": <标题>,
"summary": <执行摘要>,
"rating": <0-10 浮点评分>,
"rating_explanation": <评分说明>,
"findings": [
{
"summary": <要点小结>,
"explanation": <详细说明>
}
]
}

# 字段要求
- title: 简洁且具体,尽量包含代表性实体。
- summary: 说明社区整体结构、核心实体、关键关系与主要风险/价值。
- rating: 社区影响度/风险度评分(0-10)。
- rating_explanation: 单句说明评分依据。
- findings: 5-10 条关键发现,覆盖技术链路、任务影响、可靠性风险、协同关系等。

# 领域要求(航天)
优先关注以下维度:
- 任务阶段:论证、研制、总装、测试、发射、入轨、在轨运行、回收/退役
- 系统层级:航天器、有效载荷、推进系统、姿态控制、测控通信、地面系统
- 关键指标:推力、比冲、功率、带宽、精度、寿命、可靠性、故障率
- 风险与依赖:单点故障、接口依赖、时序耦合、供应链与试验验证缺口

# 证据约束
- 仅使用输入中可证据化信息。
- 无证据内容不得写入。
- 若信息不足,应明确指出不确定性与缺失点。

# 输入
Text:
```
{input_text}
```

Output:
"""
|
||||
|
||||
# Prompt: extract entities and explicit relationships from text, constrained
# to the provided entity-type set and a fixed relationship-type vocabulary.
PROMPTS[
    "entity_extraction"
] = """-任务目标-
给定文本与实体类型列表,识别所有相关实体,并抽取实体间"明确存在"的关系。

-实体类型约束-
- entity_type 必须来自给定集合:[{entity_types}]
- 若无法确定,使用最接近类型,不可臆造新类型

-关系类型约束-
关系描述必须以"关系类型=<类型>;依据=<说明>"开头。
关系类型从以下集合中选择:
[组成, 隶属, 控制, 被控制, 供能, 支撑, 测量, 感知, 执行, 通信, 影响, 制约, 因果, 时序前后, 协同, 风险关联, 其他]

-执行步骤-
1. 抽取实体。每个实体输出:
- entity_name: 实体名(保留原文专有名词;必要时标准化)
- entity_type: 实体类型(必须在给定集合中)
- entity_description: 面向航天任务语境的实体描述(属性、职责、行为)

实体格式:
("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)

2. 在步骤1实体中,抽取"证据充分且语义明确"的关系对(source_entity, target_entity)。
每条关系输出:
- source_entity
- target_entity
- relationship_description: 必须以"关系类型=<类型>;依据=<说明>"开头
- relationship_strength: 1-10 数值,表示关系强度

关系格式:
("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)

3. 结果使用 **{record_delimiter}** 拼接。
4. 结束时输出 {completion_delimiter}。

-质量规则-
- 只抽取文本中可直接支持的实体和关系。
- 不输出模糊、猜测或无依据关系。
- 对航天语义优先:航天器、分系统、任务阶段、参数指标、故障模式、试验事件。

-输入-
Entity_types: {entity_types}
Text: {input_text}
Output:
"""
|
||||
|
||||
# Prompt: merge multiple description fragments of one (or two) entities into
# a single coherent summary; used when graph-node descriptions grow too long.
PROMPTS[
    "summarize_entity_descriptions"
] = """你是航天知识库整理助手。
给定一个或两个实体名称,以及若干描述片段,请合并为一段一致、完整、可复用的摘要。

要求:
- 覆盖所有有效信息
- 若描述冲突,给出最一致、最保守的综合结论
- 使用第三人称
- 保留实体名,避免指代不清
- 优先保留任务阶段、技术指标、系统依赖和风险信息

#######
-输入数据-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
|
||||
|
||||
# Gleaning prompt: ask the model to emit ONLY entities/relations it missed
# in the previous round, using the same output format.
# NOTE(review): the "entiti_" key spelling is legacy naming used by the
# extraction code — keep it in sync with callers before renaming.
PROMPTS[
    "entiti_continue_extraction"
] = """上一轮可能遗漏了实体或关系。请仅补充遗漏项,严格沿用既定输出格式,不要重复已输出记录:"""
|
||||
|
||||
# Loop-control prompt: the model answers YES/NO whether another gleaning
# round is needed. ("entiti_" spelling is legacy; matches callers.)
PROMPTS[
    "entiti_if_loop_extraction"
] = """请判断是否仍有遗漏实体或关系。仅回答 YES 或 NO。"""
|
||||
|
||||
PROMPTS["DEFAULT_ENTITY_TYPES"] = [
|
||||
"航天器",
|
||||
"任务",
|
||||
"任务阶段",
|
||||
"有效载荷",
|
||||
"推进系统",
|
||||
"姿态控制系统",
|
||||
"测控通信系统",
|
||||
"电源系统",
|
||||
"热控系统",
|
||||
"结构机构",
|
||||
"传感器",
|
||||
"执行机构",
|
||||
"地面系统",
|
||||
"组织机构",
|
||||
"试验事件",
|
||||
"故障模式",
|
||||
"参数指标",
|
||||
"轨道",
|
||||
"地点",
|
||||
"时间",
|
||||
]
|
||||
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
|
||||
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
|
||||
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
|
||||
|
||||
# Prompt: answer a question from a local subgraph context (entities,
# relations, evidence snippets), with an explicit reasoning chain.
PROMPTS[
    "local_rag_response"
] = """---角色---
你是航天领域图谱问答助手,擅长基于局部子图进行多跳推理。

---任务---
根据输入的数据表(实体、关系、证据片段)回答用户问题。

---回答要求---
1. 先给结论,再给推理链路。
2. 推理链路至少包含:
- 关键实体
- 关键关系
- 中间推断
- 最终结论
3. 若证据不足,明确说明"证据不足以得出结论"。
4. 严禁编造。
5. 使用中文、专业且简洁。

---目标长度与格式---
{response_type}

---输入数据表---
{context_data}

---输出格式建议---
- 结论
- 推理链路
- 证据与不确定性
"""
|
||||
|
||||
# Prompt (map step of global query): extract scored key points from one
# community's data table; must return JSON {"points": [{description, score}]}.
PROMPTS[
    "global_map_rag_points"
] = """---角色---
你是航天知识全局研判助手,负责从社区级信息中提炼关键观点。

---任务---
根据输入数据表,输出一组可用于后续全局汇总的关键点。

---输出要求---
- 必须输出 JSON:
{
"points": [
{"description": "观点描述", "score": 0-100整数}
]
}
- description: 观点需可证据化,优先覆盖跨系统影响、任务阶段耦合、技术风险和性能趋势。
- score: 对回答用户问题的重要性分值。
- 若无法回答,输出 1 条 score=0 且明确说明信息不足。
- 仅使用输入证据,不得编造。

---输入数据表---
{context_data}
"""
|
||||
|
||||
# Prompt (reduce step of global query): fuse the ranked analyst reports from
# the map step into one decision-oriented answer.
PROMPTS[
    "global_reduce_rag_response"
] = """---角色---
你是航天领域总师级分析助手,负责融合多位分析员报告并给出全局结论。

---任务---
根据按重要性降序排列的分析员报告,输出面向决策的综合回答。

---要求---
1. 先给总体结论,再给分项分析(趋势、共性风险、关键差异、建议动作)。
2. 报告融合时去重、去噪,保留高证据密度信息。
3. 必须指出不确定性与信息缺口。
4. 仅基于输入报告,不得编造。
5. 使用中文,风格专业、严谨。

---目标长度与格式---
{response_type}

---分析员报告---
{report_data}
"""
|
||||
|
||||
# Prompt: plain (naive) RAG answer over retrieved chunk text, no graph.
PROMPTS[
    "naive_rag_response"
] = """你是航天知识问答助手。
下面是可用知识:
{content_data}
---
请基于上述知识回答用户问题,要求:
- 回答准确、简洁、专业
- 若信息不足,明确说明缺失点
- 不得编造
---目标长度与格式---
{response_type}
"""
|
||||
|
||||
PROMPTS["fail_response"] = "抱歉,当前无法基于现有信息回答该问题。"
|
||||
|
||||
PROMPTS["process_tickers"] = ["-", "\\", "|", "/"]
|
||||
|
||||
PROMPTS["default_text_separator"] = [
|
||||
"\n\n",
|
||||
"\r\n\r\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"。",
|
||||
".",
|
||||
".",
|
||||
"!",
|
||||
"!",
|
||||
"?",
|
||||
"?",
|
||||
" ",
|
||||
"\t",
|
||||
"\u3000",
|
||||
"\u200b",
|
||||
]
|
||||
Reference in New Issue
Block a user