init. project

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""nano-graphrag package interface: re-exports GraphRAG and QueryParam."""
from .graphrag import GraphRAG, QueryParam

__version__ = "0.0.8.2"
__author__ = "Jianbai Ye"
__url__ = "https://github.com/gusye1234/nano-graphrag"

# dp stands for data pack

View File

@@ -0,0 +1,301 @@
import json
import numpy as np
from typing import Optional, List, Any, Callable
try:
import aioboto3
except ImportError:
aioboto3 = None
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import os
from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
from .base import BaseKVStorage
# Module-level client singletons, created lazily by the get_*_instance helpers below.
global_openai_async_client = None
global_azure_openai_async_client = None
global_amazon_bedrock_async_client = None
def get_openai_async_client_instance():
    """Return the shared AsyncOpenAI client, creating it on first use."""
    global global_openai_async_client
    client = global_openai_async_client
    if client is None:
        client = AsyncOpenAI()
        global_openai_async_client = client
    return client
def get_azure_openai_async_client_instance():
    """Return the shared AsyncAzureOpenAI client, creating it on first use."""
    global global_azure_openai_async_client
    client = global_azure_openai_async_client
    if client is None:
        client = AsyncAzureOpenAI()
        global_azure_openai_async_client = client
    return client
def get_amazon_bedrock_async_client_instance():
    """Return the shared aioboto3 Session for Bedrock, creating it on first use.

    Raises ImportError when the optional aioboto3 dependency is missing.
    """
    global global_amazon_bedrock_async_client
    if aioboto3 is None:
        raise ImportError(
            "aioboto3 is required for Amazon Bedrock support. Install it to use Bedrock providers."
        )
    session = global_amazon_bedrock_async_client
    if session is None:
        session = aioboto3.Session()
        global_amazon_bedrock_async_client = session
    return session
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Chat-complete `prompt` with an OpenAI model, consulting `hashing_kv` as a cache.

    Args:
        model: OpenAI model name.
        prompt: user message content.
        system_prompt: optional system message prepended to the conversation.
        history_messages: prior conversation turns inserted before the prompt.
        **kwargs: ``hashing_kv`` (a BaseKVStorage used as response cache) is
            popped here; everything else is forwarded to ``chat.completions.create``.

    Returns:
        The assistant message content (possibly served from cache).
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh list.
    if history_messages is None:
        history_messages = []
    openai_async_client = get_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        # Cache key is derived from the model plus the full message list.
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Run a Bedrock Converse completion, consulting `hashing_kv` as a cache.

    Args:
        model: Bedrock model identifier, passed as ``modelId``.
        prompt: user message text (wrapped in Converse content-block format).
        system_prompt: optional system text sent via the ``system`` field.
        history_messages: prior conversation turns inserted before the prompt.
        **kwargs: ``hashing_kv`` (cache) is popped; ``max_tokens`` overrides the
            default 4096 ``maxTokens`` in the inference config.

    Returns:
        The first text block of the model reply (possibly served from cache).
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh list.
    if history_messages is None:
        history_messages = []
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    messages.extend(history_messages)
    messages.append({"role": "user", "content": [{"text": prompt}]})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    inference_config = {
        "temperature": 0,
        "maxTokens": 4096 if "max_tokens" not in kwargs else kwargs["max_tokens"],
    }
    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        # Converse rejects an empty `system` list, so only send it when present.
        if system_prompt:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
                system=[{"text": system_prompt}]
            )
        else:
            response = await bedrock_runtime.converse(
                modelId=model, messages=messages, inferenceConfig=inference_config,
            )
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response["output"]["message"]["content"][0]["text"], "model": model}}
        )
        await hashing_kv.index_done_callback()
    return response["output"]["message"]["content"][0]["text"]
def create_amazon_bedrock_complete_function(model_id: str) -> Callable:
    """
    Factory function to dynamically create completion functions for Amazon Bedrock

    Args:
        model_id (str): Amazon Bedrock model identifier (e.g., "us.anthropic.claude-3-sonnet-20240229-v1:0")

    Returns:
        Callable: Generated completion function
    """

    async def bedrock_complete(
        prompt: str,
        system_prompt: Optional[str] = None,
        history_messages: Optional[List[Any]] = None,  # None, not [] — avoid mutable default
        **kwargs
    ) -> str:
        return await amazon_bedrock_complete_if_cache(
            model_id,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages if history_messages is not None else [],
            **kwargs
        )

    # Set function name for easier debugging
    bedrock_complete.__name__ = f"{model_id}_complete"
    return bedrock_complete
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete `prompt` with gpt-4o; thin cached wrapper over openai_complete_if_cache."""
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        # Normalize None here instead of using a shared mutable default.
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete `prompt` with gpt-4o-mini; thin cached wrapper over openai_complete_if_cache."""
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        # Normalize None here instead of using a shared mutable default.
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def amazon_bedrock_embedding(texts: list[str]) -> np.ndarray:
    """Embed `texts` with amazon.titan-embed-text-v2:0, one invoke_model call per text.

    Returns an array built from the "embedding" field of each response
    (1024 dimensions, matching the decorator's embedding_dim).
    """
    amazon_bedrock_async_client = get_amazon_bedrock_async_client_instance()
    async with amazon_bedrock_async_client.client(
        "bedrock-runtime",
        region_name=os.getenv("AWS_REGION", "us-east-1")
    ) as bedrock_runtime:
        embeddings = []
        for text in texts:
            # Titan v2 accepts an explicit output dimension; keep in sync with embedding_dim above.
            body = json.dumps(
                {
                    "inputText": text,
                    "dimensions": 1024,
                }
            )
            response = await bedrock_runtime.invoke_model(
                modelId="amazon.titan-embed-text-v2:0", body=body,
            )
            response_body = await response.get("body").read()
            embeddings.append(json.loads(response_body))
    # "dp" = data pack: one parsed JSON response per input text.
    return np.array([dp["embedding"] for dp in embeddings])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed `texts` in one OpenAI text-embedding-3-small request; one row per input."""
    client = get_openai_async_client_instance()
    response = await client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in response.data]
    return np.array(vectors)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_complete_if_cache(
    deployment_name, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Chat-complete `prompt` via an Azure OpenAI deployment, consulting `hashing_kv` as a cache.

    Args:
        deployment_name: Azure deployment name (used as the ``model`` argument).
        prompt: user message content.
        system_prompt: optional system message prepended to the conversation.
        history_messages: prior conversation turns inserted before the prompt.
        **kwargs: ``hashing_kv`` (cache) is popped; the rest is forwarded to
            ``chat.completions.create``.

    Returns:
        The assistant message content (possibly served from cache).
    """
    # Avoid the shared-mutable-default pitfall: normalize None to a fresh list.
    if history_messages is None:
        history_messages = []
    azure_openai_client = get_azure_openai_async_client_instance()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(deployment_name, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    response = await azure_openai_client.chat.completions.create(
        model=deployment_name, messages=messages, **kwargs
    )
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response.choices[0].message.content,
                    "model": deployment_name,
                }
            }
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content
async def azure_gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete `prompt` against the Azure deployment named "gpt-4o" (cached)."""
    return await azure_openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        # Normalize None here instead of using a shared mutable default.
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
async def azure_gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete `prompt` against the Azure deployment named "gpt-4o-mini" (cached)."""
    return await azure_openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        # Normalize None here instead of using a shared mutable default.
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed `texts` in one Azure OpenAI request against the "text-embedding-3-small" deployment.

    Returns one 1536-dim row per input text (dim per the decorator above).
    """
    azure_openai_client = get_azure_openai_async_client_instance()
    response = await azure_openai_client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
from typing import List, Optional, Union, Literal
class SeparatorSplitter:
    """Split a token sequence on separator token patterns, then merge the pieces
    into chunks bounded by `chunk_size`, optionally overlapping by `chunk_overlap`.
    """

    def __init__(
        self,
        separators: Optional[List[List[int]]] = None,
        keep_separator: Union[bool, Literal["start", "end"]] = "end",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: callable = len,
    ):
        self._separators = separators or []
        self._keep_separator = keep_separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
        """Split on separators, then merge into size-bounded (optionally overlapping) chunks."""
        pieces = self._split_tokens_with_separators(tokens)
        return self._merge_splits(pieces)

    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
        """Cut `tokens` at every separator occurrence, honoring `keep_separator`."""
        pieces: List[List[int]] = []
        buffer: List[int] = []
        pos = 0
        total = len(tokens)
        while pos < total:
            matched = None
            for sep in self._separators:
                if tokens[pos:pos + len(sep)] == sep:
                    matched = sep
                    break
            if matched is None:
                buffer.append(tokens[pos])
                pos += 1
                continue
            # Separator found: attach it to the end of the current piece,
            # the start of the next one, or drop it, per `keep_separator`.
            if self._keep_separator in (True, "end"):
                buffer.extend(matched)
            if buffer:
                pieces.append(buffer)
            buffer = []
            if self._keep_separator == "start":
                buffer = list(matched)
            pos += len(matched)
        if buffer:
            pieces.append(buffer)
        return [p for p in pieces if p]

    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
        """Greedily pack consecutive pieces into chunks of at most `chunk_size` tokens."""
        if not splits:
            return []
        chunks: List[List[int]] = []
        pending: List[int] = []
        for piece in splits:
            if not pending:
                pending = piece
            elif self._length_function(pending) + self._length_function(piece) <= self._chunk_size:
                pending.extend(piece)
            else:
                chunks.append(pending)
                pending = piece
        if pending:
            chunks.append(pending)
        # A single oversized chunk is re-cut with a sliding window instead.
        if len(chunks) == 1 and self._length_function(chunks[0]) > self._chunk_size:
            return self._split_chunk(chunks[0])
        if self._chunk_overlap > 0:
            return self._enforce_overlap(chunks)
        return chunks

    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
        """Slide a `chunk_size` window over one oversized chunk, stepping by size - overlap."""
        out: List[List[int]] = []
        step = self._chunk_size - self._chunk_overlap
        for start in range(0, len(chunk), step):
            window = chunk[start:start + self._chunk_size]
            # Only keep windows longer than the overlap itself (drops tiny tails).
            if len(window) > self._chunk_overlap:
                out.append(window)
        return out

    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
        """Prefix every chunk after the first with the previous chunk's last `chunk_overlap` tokens."""
        padded: List[List[int]] = []
        for idx, chunk in enumerate(chunks):
            if idx == 0:
                padded.append(chunk)
                continue
            tail = chunks[idx - 1][-self._chunk_overlap:]
            candidate = tail + chunk
            if self._length_function(candidate) > self._chunk_size:
                candidate = candidate[:self._chunk_size]
            padded.append(candidate)
        return padded

View File

@@ -0,0 +1,9 @@
from .gdb_networkx import NetworkXStorage
from .gdb_neo4j import Neo4jStorage
from .vdb_nanovectordb import NanoVectorDBStorage
from .kv_json import JsonKVStorage
try:
from .vdb_hnswlib import HNSWVectorStorage
except ImportError:
HNSWVectorStorage = None

View File

@@ -0,0 +1,529 @@
import json
import asyncio
from collections import defaultdict
from typing import List
from neo4j import AsyncGraphDatabase
from dataclasses import dataclass
from typing import Union
from ..base import BaseGraphStorage, SingleCommunitySchema
from .._utils import logger
from ..prompt import GRAPH_FIELD_SEP
neo4j_lock = asyncio.Lock()

# Character substitutions that turn a filesystem path into a label-safe identifier.
_PATH_ID_TRANS = str.maketrans({".": "_", "/": "__", "-": "_", ":": "_", "\\": "__"})


def make_path_idable(path):
    """Sanitize a working-dir path so it can be embedded in a Neo4j label."""
    return path.translate(_PATH_ID_TRANS)
@dataclass
class Neo4jStorage(BaseGraphStorage):
    """Graph storage backed by a Neo4j server.

    Every node managed by this instance carries the label ``self.namespace``
    (sanitized working-dir path + namespace), so multiple projects can share
    one database. Connection settings come from
    ``global_config["addon_params"]`` (``neo4j_url`` / ``neo4j_auth``).
    """

    def __post_init__(self):
        # Connection parameters must be supplied via addon_params.
        self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
        self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
        # Prefix the namespace with a path-derived id so labels are unique per working dir.
        self.namespace = (
            f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
        )
        logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
        if self.neo4j_url is None or self.neo4j_auth is None:
            raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
        self.async_driver = AsyncGraphDatabase.driver(
            self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
        )

    # async def create_database(self):
    #     async with self.async_driver.session() as session:
    #         try:
    #             constraints = await session.run("SHOW CONSTRAINTS")
    #             # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
    #             # so have to check if the constrain exists
    #             constrain_exists = False
    #             async for record in constraints:
    #                 if (
    #                     self.namespace in record["labelsOrTypes"]
    #                     and "id" in record["properties"]
    #                     and record["type"] == "UNIQUENESS"
    #                 ):
    #                     constrain_exists = True
    #                     break
    #             if not constrain_exists:
    #                 await session.run(
    #                     f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
    #                 )
    #                 logger.info(f"Add constraint for namespace: {self.namespace}")
    #         except Exception as e:
    #             logger.error(f"Error accessing or setting up the database: {str(e)}")
    #             raise

    async def _init_workspace(self):
        """Verify credentials and connectivity before any query runs."""
        await self.async_driver.verify_authentication()
        await self.async_driver.verify_connectivity()
        # TODOLater: create database if not exists always cause an error when async
        # await self.create_database()

    async def index_start_callback(self):
        """Check the connection and create per-namespace property indexes."""
        logger.info("Init Neo4j workspace")
        await self._init_workspace()
        # create index for faster searching
        try:
            async with self.async_driver.session() as session:
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
                )
                logger.info("Neo4j indexes created successfully")
        except Exception as e:
            logger.error(f"Failed to create indexes: {e}")
            raise e

    async def has_node(self, node_id: str) -> bool:
        """True when a node with this id exists under the namespace label."""
        async with self.async_driver.session() as session:
            result = await session.run(
                f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
                node_id=node_id,
            )
            record = await result.single()
            return record["exists"] if record else False

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """True when a directed relationship source -> target exists in this namespace."""
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                MATCH (s:`{self.namespace}`)
                WHERE s.id = $source_id
                MATCH (t:`{self.namespace}`)
                WHERE t.id = $target_id
                RETURN EXISTS((s)-[]->(t)) AS exists
                """,
                source_id=source_node_id,
                target_id=target_node_id,
            )
            record = await result.single()
            return record["exists"] if record else False

    async def node_degree(self, node_id: str) -> int:
        """Degree of one node; convenience wrapper over node_degrees_batch."""
        results = await self.node_degrees_batch([node_id])
        return results[0] if results else 0

    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        """Degrees for many nodes in one round-trip; order matches `node_ids`.

        Unknown node ids come back as 0.
        """
        if not node_ids:
            # NOTE(review): returns {} here while the normal path returns a list;
            # callers only see a falsy empty container either way.
            return {}
        result_dict = {node_id: 0 for node_id in node_ids}
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $node_ids AS node_id
                MATCH (n:`{self.namespace}`)
                WHERE n.id = node_id
                OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
                RETURN node_id, COUNT(m) AS degree
                """,
                node_ids=node_ids
            )
            async for record in result:
                result_dict[record["node_id"]] = record["degree"]
        return [result_dict[node_id] for node_id in node_ids]

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Sum of both endpoint degrees for one edge; wrapper over the batch call."""
        results = await self.edge_degrees_batch([(src_id, tgt_id)])
        return results[0] if results else 0

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        """Endpoint-degree sums for many (src, tgt) pairs; 0s on query failure."""
        if not edge_pairs:
            return []
        result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}
        edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge
                    MATCH (s:`{self.namespace}`)
                    WHERE s.id = edge.src_id
                    WITH edge, s
                    OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
                    WITH edge, COUNT(n1) AS src_degree
                    MATCH (t:`{self.namespace}`)
                    WHERE t.id = edge.tgt_id
                    WITH edge, src_degree, t
                    OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
                    WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree
                    RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
                    """,
                    edges=edges_params
                )
                async for record in result:
                    src_id = record["src_id"]
                    tgt_id = record["tgt_id"]
                    degree = record["degree"]
                    # Update result dict keyed by the original pair.
                    edge_pair = (src_id, tgt_id)
                    result_dict[edge_pair] = degree
            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge degree calculation: {e}")
            return [0] * len(edge_pairs)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Properties of one node (plus derived "clusters"), or None when absent."""
        result = await self.get_nodes_batch([node_id])
        return result[0] if result else None

    async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
        """Fetch many node property maps at once; order matches `node_ids`.

        Each found node additionally gets a JSON "clusters" field derived from
        its communityIds (one entry per hierarchy level).
        """
        if not node_ids:
            return {}
        result_dict = {node_id: None for node_id in node_ids}
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $node_ids AS node_id
                    MATCH (n:`{self.namespace}`)
                    WHERE n.id = node_id
                    RETURN node_id, properties(n) AS node_data
                    """,
                    node_ids=node_ids
                )
                async for record in result:
                    node_id = record["node_id"]
                    raw_node_data = record["node_data"]
                    if raw_node_data:
                        # communityIds index position encodes the cluster level.
                        raw_node_data["clusters"] = json.dumps(
                            [
                                {
                                    "level": index,
                                    "cluster": cluster_id,
                                }
                                for index, cluster_id in enumerate(
                                    raw_node_data.get("communityIds", [])
                                )
                            ]
                        )
                        result_dict[node_id] = raw_node_data
            return [result_dict[node_id] for node_id in node_ids]
        except Exception as e:
            logger.error(f"Error in batch node retrieval: {e}")
            raise e

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Properties of one directed edge, or None when absent."""
        results = await self.get_edges_batch([(source_node_id, target_node_id)])
        return results[0] if results else None

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        """Fetch property maps for many directed edges; order matches `edge_pairs`."""
        if not edge_pairs:
            return []
        result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}
        edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = edge.source_id AND t.id = edge.target_id
                    RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
                    """,
                    edges=edges_params
                )
                async for record in result:
                    source_id = record["source_id"]
                    target_id = record["target_id"]
                    edge_data = record["edge_data"]
                    edge_pair = (source_id, target_id)
                    result_dict[edge_pair] = edge_data
            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge retrieval: {e}")
            return [None] * len(edge_pairs)

    async def get_node_edges(
        self, source_node_id: str
    ) -> list[tuple[str, str]]:
        """Outgoing (source, target) pairs for one node; wrapper over the batch call."""
        results = await self.get_nodes_edges_batch([source_node_id])
        return results[0] if results else []

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        """Outgoing edge lists for many nodes; order matches `node_ids`."""
        if not node_ids:
            return []
        result_dict = {node_id: [] for node_id in node_ids}
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $node_ids AS node_id
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = node_id
                    RETURN s.id AS source_id, t.id AS target_id
                    """,
                    node_ids=node_ids
                )
                async for record in result:
                    source_id = record["source_id"]
                    target_id = record["target_id"]
                    if source_id in result_dict:
                        result_dict[source_id].append((source_id, target_id))
            return [result_dict[node_id] for node_id in node_ids]
        except Exception as e:
            logger.error(f"Error in batch node edges retrieval: {e}")
            return [[] for _ in node_ids]

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert or update a single node; wrapper over the batch call."""
        await self.upsert_nodes_batch([(node_id, node_data)])

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        """MERGE many nodes, grouped by entity_type so each gets a type label too."""
        if not nodes_data:
            return []
        # Group by entity type because the secondary label must be baked into the query text.
        nodes_by_type = {}
        for node_id, node_data in nodes_data:
            node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
            if node_type not in nodes_by_type:
                nodes_by_type[node_type] = []
            nodes_by_type[node_type].append((node_id, node_data))
        async with self.async_driver.session() as session:
            for node_type, type_nodes in nodes_by_type.items():
                params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]
                await session.run(
                    f"""
                    UNWIND $nodes AS node
                    MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
                    SET n += node.data
                    """,
                    nodes=params
                )

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Insert or update a single RELATED edge; wrapper over the batch call."""
        await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        """MERGE many RELATED edges; guarantees a numeric `weight` on each (default 0.0)."""
        if not edges_data:
            return
        edges_params = []
        for source_id, target_id, edge_data in edges_data:
            # Copy so the caller's dict is not mutated by the weight default.
            edge_data_copy = edge_data.copy()
            edge_data_copy.setdefault("weight", 0.0)
            edges_params.append({
                "source_id": source_id,
                "target_id": target_id,
                "edge_data": edge_data_copy
            })
        async with self.async_driver.session() as session:
            await session.run(
                f"""
                UNWIND $edges AS edge
                MATCH (s:`{self.namespace}`)
                WHERE s.id = edge.source_id
                WITH edge, s
                MATCH (t:`{self.namespace}`)
                WHERE t.id = edge.target_id
                MERGE (s)-[r:RELATED]->(t)
                SET r += edge.edge_data
                """,
                edges=edges_params
            )

    async def clustering(self, algorithm: str):
        """Run Leiden community detection via the Neo4j GDS plugin.

        Projects the namespace subgraph (undirected, weighted), runs
        ``gds.leiden.write`` persisting hierarchical community ids into the
        ``communityIds`` node property, then drops the projection.

        Raises:
            ValueError: for any algorithm other than "leiden".
        """
        if algorithm != "leiden":
            raise ValueError(
                f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
            )
        random_seed = self.global_config["graph_cluster_seed"]
        max_level = self.global_config["max_graph_cluster_size"]
        async with self.async_driver.session() as session:
            try:
                # Project the graph with undirected relationships
                await session.run(
                    f"""
                    CALL gds.graph.project(
                        'graph_{self.namespace}',
                        ['{self.namespace}'],
                        {{
                            RELATED: {{
                                orientation: 'UNDIRECTED',
                                properties: ['weight']
                            }}
                        }}
                    )
                    """
                )
                # Run Leiden algorithm
                result = await session.run(
                    f"""
                    CALL gds.leiden.write(
                        'graph_{self.namespace}',
                        {{
                            writeProperty: 'communityIds',
                            includeIntermediateCommunities: True,
                            relationshipWeightProperty: "weight",
                            maxLevels: {max_level},
                            tolerance: 0.0001,
                            gamma: 1.0,
                            theta: 0.01,
                            randomSeed: {random_seed}
                        }}
                    )
                    YIELD communityCount, modularities;
                    """
                )
                result = await result.single()
                community_count: int = result["communityCount"]
                modularities = result["modularities"]
                logger.info(
                    f"Performed graph clustering with {community_count} communities and modularities {modularities}"
                )
            finally:
                # Drop the projected graph
                await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Build the community -> {level, nodes, edges, chunk_ids, ...} summary.

        Reads each node's communityIds (index = hierarchy level) plus its
        neighbors, aggregates per community, and computes `occurrence` as the
        community's chunk count normalized by the largest community's.
        """
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        async with self.async_driver.session() as session:
            # Fetch community data
            result = await session.run(
                f"""
                MATCH (n:`{self.namespace}`)
                WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
                RETURN n.id AS node_id, n.source_id AS source_id,
                       communityIds AS cluster_key,
                       connected_nodes
                """
            )
            # records = await result.fetch()
            max_num_ids = 0
            async for record in result:
                for index, c_id in enumerate(record["cluster_key"]):
                    node_id = str(record["node_id"])
                    source_id = record["source_id"]
                    level = index
                    cluster_key = str(c_id)
                    connected_nodes = record["connected_nodes"]
                    results[cluster_key]["level"] = level
                    results[cluster_key]["title"] = f"Cluster {cluster_key}"
                    results[cluster_key]["nodes"].add(node_id)
                    # Store undirected edges as sorted pairs to deduplicate.
                    results[cluster_key]["edges"].update(
                        [
                            tuple(sorted([node_id, str(connected)]))
                            for connected in connected_nodes
                            if connected != node_id
                        ]
                    )
                    chunk_ids = source_id.split(GRAPH_FIELD_SEP)
                    results[cluster_key]["chunk_ids"].update(chunk_ids)
                    max_num_ids = max(
                        max_num_ids, len(results[cluster_key]["chunk_ids"])
                    )
            # Process results
            for k, v in results.items():
                v["edges"] = [list(e) for e in v["edges"]]
                v["nodes"] = list(v["nodes"])
                v["chunk_ids"] = list(v["chunk_ids"])
                v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        # Compute sub-communities (this is a simplified approach)
        for cluster in results.values():
            cluster["sub_communities"] = [
                sub_key
                for sub_key, sub_cluster in results.items()
                if sub_cluster["level"] > cluster["level"]
                and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
            ]
        return dict(results)

    async def index_done_callback(self):
        """Close the driver; called once indexing is finished."""
        await self.async_driver.close()

    async def _debug_delete_all_node_edges(self):
        """Testing helper: wipe every node and relationship in this namespace."""
        async with self.async_driver.session() as session:
            try:
                # Delete all relationships in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")
                # Delete all nodes in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")
                logger.info(
                    f"All nodes and edges in namespace '{self.namespace}' have been deleted."
                )
            except Exception as e:
                logger.error(f"Error deleting nodes and edges: {str(e)}")
                raise

View File

@@ -0,0 +1,268 @@
import html
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Union, cast, List
import networkx as nx
import numpy as np
import asyncio
from .._utils import logger
from ..base import (
BaseGraphStorage,
SingleCommunitySchema,
)
from ..prompt import GRAPH_FIELD_SEP
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
    """Load a persisted GraphML graph, or return None when the file is absent."""
    if not os.path.exists(file_name):
        return None
    return nx.read_graphml(file_name)
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
    """Persist `graph` to `file_name` as GraphML, logging its size."""
    logger.info(
        f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
    )
    nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
    Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
    """
    # Imported lazily: graspologic is only needed for clustering paths.
    from graspologic.utils import largest_connected_component

    graph = graph.copy()
    graph = cast(nx.Graph, largest_connected_component(graph))
    # Normalize node ids: HTML-unescape, uppercase, strip whitespace.
    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
    graph = nx.relabel_nodes(graph, node_mapping)
    return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
    Ensure an undirected graph with the same relationships will always be read the same way.
    """
    stable = nx.DiGraph() if graph.is_directed() else nx.Graph()
    # Insert nodes in lexicographic order so serialization is deterministic.
    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))
    edges = list(graph.edges(data=True))
    if not graph.is_directed():
        # Orient every undirected edge from the smaller to the larger endpoint id.
        def _canonical(edge):
            source, target, data = edge
            if source > target:
                source, target = target, source
            return source, target, data

        edges = [_canonical(edge) for edge in edges]
    # Sort edges by the stable textual key "source -> target".
    edges.sort(key=lambda e: f"{e[0]} -> {e[1]}")
    stable.add_edges_from(edges)
    return stable
def __post_init__(self):
    """Load (or create) the per-namespace GraphML file and register algorithms."""
    # Each namespace persists to its own GraphML file under working_dir.
    self._graphml_xml_file = os.path.join(
        self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
    )
    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
    if preloaded_graph is not None:
        logger.info(
            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
        )
    # Fall back to a fresh empty graph when nothing was on disk.
    self._graph = preloaded_graph or nx.Graph()
    # Dispatch tables for pluggable algorithms.
    self._clustering_algorithms = {
        "leiden": self._leiden_clustering,
    }
    self._node_embed_algorithms = {
        "node2vec": self._node2vec_embed,
    }
async def index_done_callback(self):
    """Persist the in-memory graph to its GraphML file after indexing."""
    NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
    """True when `node_id` exists in the in-memory graph."""
    return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """True when an edge between the two nodes exists."""
    return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
    """Attribute dict of `node_id`, or None when absent."""
    return self._graph.nodes.get(node_id)
async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
    """Attribute dicts for many nodes; gather preserves `node_ids` order (list, not dict)."""
    return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])
async def node_degree(self, node_id: str) -> int:
    """Degree of `node_id`; 0 when the node is absent."""
    # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
    return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
    """Degrees for many nodes; order matches `node_ids`."""
    return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Sum of the two endpoint degrees; a missing endpoint counts as 0."""
    src_degree = self._graph.degree(src_id) if self._graph.has_node(src_id) else 0
    tgt_degree = self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
    return src_degree + tgt_degree
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
    """Endpoint-degree sums for many (src, tgt) pairs; order matches `edge_pairs`."""
    tasks = [self.edge_degree(src, tgt) for src, tgt in edge_pairs]
    return await asyncio.gather(*tasks)
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
    """Attribute dict of the edge, or None when absent."""
    return self._graph.edges.get((source_node_id, target_node_id))
async def get_edges_batch(
    self, edge_pairs: list[tuple[str, str]]
) -> list[Union[dict, None]]:
    """Attribute dicts for many edges; order matches `edge_pairs`."""
    tasks = [self.get_edge(src, tgt) for src, tgt in edge_pairs]
    return await asyncio.gather(*tasks)
async def get_node_edges(self, source_node_id: str):
    """Edges incident to `source_node_id` as (u, v) pairs, or None when the node is absent."""
    if self._graph.has_node(source_node_id):
        return list(self._graph.edges(source_node_id))
    return None
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> list[list[tuple[str, str]]]:
    """Edge lists for many nodes; order matches `node_ids`."""
    tasks = [self.get_node_edges(node_id) for node_id in node_ids]
    return await asyncio.gather(*tasks)
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    """Add `node_id` (or merge attributes into the existing node)."""
    self._graph.add_node(node_id, **node_data)
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
    """Upsert many (node_id, attributes) pairs."""
    tasks = [self.upsert_node(node_id, data) for node_id, data in nodes_data]
    await asyncio.gather(*tasks)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
    """Add the edge (or merge attributes into the existing edge)."""
    self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def upsert_edges_batch(
    self, edges_data: list[tuple[str, str, dict[str, str]]]
):
    """Upsert many (source, target, attributes) edges."""
    tasks = [
        self.upsert_edge(src, tgt, data)
        for src, tgt, data in edges_data
    ]
    await asyncio.gather(*tasks)
async def clustering(self, algorithm: str):
    """Run a registered clustering algorithm (currently only "leiden").

    Raises:
        ValueError: when `algorithm` is not registered.
    """
    if algorithm not in self._clustering_algorithms:
        raise ValueError(f"Clustering algorithm {algorithm} not supported")
    await self._clustering_algorithms[algorithm]()
    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Aggregate the per-node cluster annotations into a community schema.

        Reads the JSON "clusters" attribute that clustering() stored on each
        node, groups nodes / undirected edges / source chunk ids per cluster
        key, links each community to its sub-communities on the next level,
        and scores each community's "occurrence" relative to the community
        with the most chunk ids.

        Returns:
            Mapping of cluster key -> SingleCommunitySchema (JSON-serializable).
        """
        # Default entry for any cluster key seen for the first time.
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        # Largest chunk-id count across communities; denominator for occurrence.
        max_num_ids = 0
        # level -> set of cluster keys present on that level.
        levels = defaultdict(set)
        for node_id, node_data in self._graph.nodes(data=True):
            # Skip nodes that were never assigned to any cluster.
            if "clusters" not in node_data:
                continue
            clusters = json.loads(node_data["clusters"])
            this_node_edges = self._graph.edges(node_id)

            for cluster in clusters:
                level = cluster["level"]
                cluster_key = str(cluster["cluster"])
                levels[level].add(cluster_key)
                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                # Sort endpoints so the same undirected edge dedupes in the set.
                results[cluster_key]["edges"].update(
                    [tuple(sorted(e)) for e in this_node_edges]
                )
                # A node's source_id is a GRAPH_FIELD_SEP-joined list of chunk ids.
                results[cluster_key]["chunk_ids"].update(
                    node_data["source_id"].split(GRAPH_FIELD_SEP)
                )
                max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))

        ordered_levels = sorted(levels.keys())
        for i, curr_level in enumerate(ordered_levels[:-1]):
            next_level = ordered_levels[i + 1]
            this_level_comms = levels[curr_level]
            next_level_comms = levels[next_level]
            # compute the sub-communities by nodes intersection
            for comm in this_level_comms:
                results[comm]["sub_communities"] = [
                    c
                    for c in next_level_comms
                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                ]

        # Convert sets to lists (and edge tuples to lists) for serialization.
        for k, v in results.items():
            v["edges"] = list(v["edges"])
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            # Relative community size; max_num_ids > 0 whenever results is non-empty.
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        return dict(results)
def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
for node_id, clusters in cluster_data.items():
self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
    async def _leiden_clustering(self):
        """Hierarchically cluster the graph with Leiden and store results on nodes.

        Only the (stabilized) largest connected component is clustered; each
        clustered node receives a JSON "clusters" attribute listing its
        (level, cluster) memberships via _cluster_data_to_subgraphs.
        """
        # Imported lazily: graspologic is an optional, heavy dependency.
        from graspologic.partition import hierarchical_leiden

        graph = NetworkXStorage.stable_largest_connected_component(self._graph)
        community_mapping = hierarchical_leiden(
            graph,
            max_cluster_size=self.global_config["max_graph_cluster_size"],
            # Fixed seed keeps the partition deterministic across runs.
            random_seed=self.global_config["graph_cluster_seed"],
        )

        node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
        # level -> set of distinct cluster ids (used only for the log line below).
        __levels = defaultdict(set)
        for partition in community_mapping:
            level_key = partition.level
            cluster_id = partition.cluster
            node_communities[partition.node].append(
                {"level": level_key, "cluster": cluster_id}
            )
            __levels[level_key].add(cluster_id)
        node_communities = dict(node_communities)
        __levels = {k: len(v) for k, v in __levels.items()}
        logger.info(f"Each level has communities: {dict(__levels)}")
        self._cluster_data_to_subgraphs(node_communities)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
    async def _node2vec_embed(self):
        """Embed the graph nodes with node2vec.

        Returns:
            (embeddings, node ids) — ids are taken from each node's "id"
            attribute. NOTE(review): assumes every embedded node carries an
            "id" attribute; confirm against the upsert path.
        """
        # Imported lazily: graspologic is an optional, heavy dependency.
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

View File

@@ -0,0 +1,46 @@
import os
from dataclasses import dataclass
from .._utils import load_json, logger, write_json
from ..base import (
BaseKVStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage backed by a single JSON file under the working dir.

    All operations act on an in-memory dict; index_done_callback() flushes it
    back to disk.
    """

    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        # load_json returns None when the file does not exist yet.
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def index_done_callback(self):
        # Persist the in-memory dict back to the JSON file.
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        """Fetch several values; missing ids map to None.

        When fields is given, each present value is projected to those keys.
        BUGFIX: presence is now tested with `in` — the previous truthiness
        check returned None for existing-but-falsy values (e.g. {}).
        """
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
            (
                {k: v for k, v in self._data[id].items() if k in fields}
                if id in self._data
                else None
            )
            for id in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of keys NOT present in the store."""
        return {s for s in data if s not in self._data}

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)

    async def drop(self):
        self._data = {}

View File

@@ -0,0 +1,141 @@
import asyncio
import os
from dataclasses import dataclass, field
from typing import Any
import pickle
import hnswlib
import numpy as np
import xxhash
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an on-disk hnswlib index (cosine space)."""

    # hnswlib build/search parameters; overridable via
    # global_config["vector_db_storage_cls_kwargs"] in __post_init__.
    ef_construction: int = 100
    M: int = 16
    max_elements: int = 1000000
    ef_search: int = 50
    num_threads: int = -1
    # hnswlib.Index instance, created in __post_init__.
    _index: Any = field(init=False)
    # int label -> stored meta fields (plus "id") for each vector.
    _metadata: dict[str, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        """Load an existing index + metadata from disk if present, else create new."""
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)

        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )

        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            # NOTE(review): pickle.load on a local cache file — only safe as long
            # as the working dir is trusted.
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        """Embed each value's "content" in batches and add the vectors to the index.

        Returns the uint32 labels assigned to the inserted items.
        NOTE(review): returns a plain [] (not an ndarray) for empty input —
        annotation and empty-case return type disagree.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        # Capacity is fixed at init; fail fast rather than corrupt the index.
        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )

        # Keep only the configured meta fields alongside the original key.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batch_size = min(self._embedding_batch_num, len(contents))
        # Embed all batches concurrently, then stack into one matrix.
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )
        # hnswlib labels are ints: derive a stable 32-bit hash from each string id.
        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Embed the query and return up to top_k nearest items with their metadata."""
        if self._current_elements == 0:
            return []

        # hnswlib cannot return more results than stored elements.
        top_k = min(top_k, self._current_elements)

        if top_k > self.ef_search:
            # ef must be >= k for knn_query; note this widens ef permanently.
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)

        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )

        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                # Cosine space: similarity is the complement of the distance.
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        """Persist the index and its metadata to disk."""
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)

View File

@@ -0,0 +1,68 @@
import asyncio
import os
from dataclasses import dataclass
import numpy as np
from nano_vectordb import NanoVectorDB
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by NanoVectorDB, persisted as a JSON file."""

    # Minimum cosine similarity for query hits; overridable via
    # global_config["query_better_than_threshold"].
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "query_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed each value's "content" (in batches) and upsert into the client.

        Keys become "__id__"; only the configured meta_fields are copied.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Split into batches to respect the embedding function's limits.
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        """Embed the query string and return top_k matches above the threshold."""
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        # Expose the client-internal id/metric under friendlier keys.
        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]
        return results

    async def index_done_callback(self):
        # Flush the client's data to its JSON storage file.
        self._client.save()

View File

@@ -0,0 +1,307 @@
import asyncio
import html
import json
import logging
import os
import re
import numbers
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, Literal
import numpy as np
import tiktoken
try:
from transformers import AutoTokenizer
except ImportError:
AutoTokenizer = None
logger = logging.getLogger("nano-graphrag")
logging.getLogger("neo4j").setLevel(logging.ERROR)
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return a usable asyncio event loop, creating one when none exists.

    NOTE(review): asyncio.get_event_loop() is deprecated for this pattern on
    Python 3.10+; consider asyncio.new_event_loop()/asyncio.run if this ever
    warns in newer runtimes.
    """
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # If in a sub-thread, create a new event loop.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
def extract_first_complete_json(s: str):
    """Extract the first complete JSON object from the string using a stack to track braces."""
    depth = 0
    start_idx = None
    for idx, ch in enumerate(s):
        if ch == "{":
            if start_idx is None:
                start_idx = idx
            depth += 1
        elif ch == "}" and depth > 0:
            depth -= 1
            if depth == 0:
                # First balanced span found: this is the only parse attempt.
                first_json_str = s[start_idx : idx + 1]
                try:
                    # Attempt to parse the JSON string
                    return json.loads(first_json_str.replace("\n", ""))
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
                    return None
    logger.warning("No complete JSON object found in the input string.")
    return None
def parse_value(value: str):
    """Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
    value = value.strip()
    literals = {"null": None, "true": True, "false": False}
    if value in literals:
        return literals[value]
    try:
        # A dot suggests a float; otherwise try an int.
        return float(value) if "." in value else int(value)
    except ValueError:
        # Not numeric: treat as a string, dropping surrounding double quotes.
        return value.strip('"')
def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
    """Extract key values from a non-standard or malformed JSON string, handling nested objects.

    NOTE(review): the `keys` and `allow_no_quotes` parameters are accepted but
    never used by the body — every key the regex finds is extracted; confirm
    whether filtering by `keys` was intended.
    """
    extracted_values = {}

    # Enhanced pattern to match both quoted and unquoted values, as well as nested objects
    regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'

    for match in re.finditer(regex_pattern, json_string, re.DOTALL):
        key = match.group('key').strip('"')  # Strip quotes from key
        value = match.group('value').strip()

        # If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
        if value.startswith('{') and value.endswith('}'):
            extracted_values[key] = extract_values_from_json(value)
        else:
            # Parse the value into the appropriate type (int, float, bool, etc.)
            extracted_values[key] = parse_value(value)

    if not extracted_values:
        logger.warning("No values could be extracted from the string.")

    return extracted_values
def convert_response_to_json(response: str) -> dict:
    """Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
    parsed = extract_first_complete_json(response)
    if parsed is None:
        # Strict parsing failed: fall back to the lenient regex extractor.
        logger.info("Attempting to extract values from a non-standard JSON string...")
        parsed = extract_values_from_json(response, allow_no_quotes=True)
    if not parsed:
        logger.error("Unable to extract meaningful data from the response.")
    else:
        logger.info("JSON data successfully extracted.")
    return parsed
class TokenizerWrapper:
    """Uniform encode/decode facade over tiktoken or HuggingFace tokenizers."""

    def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
        self.tokenizer_type = tokenizer_type
        self.model_name = model_name
        self._tokenizer = None
        # Eagerly load so configuration errors surface at construction time.
        self._lazy_load_tokenizer()

    def _lazy_load_tokenizer(self):
        """Load the underlying tokenizer once; later calls are no-ops."""
        if self._tokenizer is not None:
            return
        logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
        if self.tokenizer_type == "tiktoken":
            self._tokenizer = tiktoken.encoding_for_model(self.model_name)
        elif self.tokenizer_type == "huggingface":
            if AutoTokenizer is None:
                raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")

    def get_tokenizer(self):
        """Expose the underlying tokenizer object for special cases (e.g. decode_batch)."""
        self._lazy_load_tokenizer()
        return self._tokenizer

    def encode(self, text: str) -> list[int]:
        """Encode text to a list of token ids."""
        self._lazy_load_tokenizer()
        return self._tokenizer.encode(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode a list of token ids back to text."""
        self._lazy_load_tokenizer()
        return self._tokenizer.decode(tokens)

    # Batch decoding helper, added for efficiency and interface consistency.
    def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
        """Decode several token-id lists at once."""
        self._lazy_load_tokenizer()
        # HuggingFace tokenizers provide batch_decode; tiktoken does not, so
        # emulate it with a list comprehension.
        if self.tokenizer_type == "tiktoken":
            return [self._tokenizer.decode(tokens) for tokens in tokens_list]
        elif self.tokenizer_type == "huggingface":
            return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
def truncate_list_by_token_size(
    list_data: list,
    key: callable,
    max_token_size: int,
    tokenizer_wrapper: TokenizerWrapper
):
    """Truncate a list of data by token size using a provided tokenizer wrapper."""
    if max_token_size <= 0:
        return []
    running_total = 0
    for index, item in enumerate(list_data):
        # +1 is defensive: emulates the items later being joined with "\n".
        running_total += len(tokenizer_wrapper.encode(key(item))) + 1
        if running_total > max_token_size:
            return list_data[:index]
    return list_data
def compute_mdhash_id(content, prefix: str = ""):
    """Return prefix + MD5 hex digest of content — a stable content-addressed id."""
    digest = md5(content.encode()).hexdigest()
    return f"{prefix}{digest}"
def write_json(json_obj, file_name):
    """Serialize json_obj to file_name as pretty-printed UTF-8 JSON."""
    with open(file_name, "w", encoding="utf-8") as fp:
        json.dump(json_obj, fp, indent=2, ensure_ascii=False)
def load_json(file_name):
    """Load JSON from file_name; return None when the file does not exist."""
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as fp:
        return json.load(fp)
# it's dirty to type, so it's a good way to have fun
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
if using_amazon_bedrock:
return [
{"role": "user", "content": [{"text": prompt}]},
{"role": "assistant", "content": [{"text": generated_content}]},
]
else:
return [
{"role": "user", "content": prompt},
{"role": "assistant", "content": generated_content},
]
# Optionally signed decimal number, e.g. "42", "-1.5", "+.25".
_FLOAT_PATTERN = re.compile(r"^[-+]?[0-9]*\.?[0-9]+$")


def is_float_regex(value):
    """True when value looks like an optionally signed decimal number string."""
    return _FLOAT_PATTERN.match(value) is not None
def compute_args_hash(*args):
    """Deterministic MD5 hex digest of the repr of the argument tuple (cache key)."""
    serialized = str(args).encode()
    return md5(serialized).hexdigest()
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    pattern = "|".join(re.escape(marker) for marker in markers)
    stripped = (piece.strip() for piece in re.split(pattern, content))
    return [piece for piece in stripped if piece]
def enclose_string_with_quotes(content: Any) -> str:
    """Render a value for CSV output: numbers stay bare, everything else is double-quoted."""
    if isinstance(content, numbers.Number):
        return str(content)
    # Drop whitespace and any surrounding quotes before re-quoting.
    text = str(content).strip().strip("'").strip('"')
    return f'"{text}"'
def list_of_list_to_csv(data: list[list]):
    """Render rows as a comma+tab separated pseudo-CSV string, one row per line."""
    rendered_rows = []
    for row in data:
        cells = [enclose_string_with_quotes(cell) for cell in row]
        rendered_rows.append(",\t".join(cells))
    return "\n".join(rendered_rows)
# -----------------------------------------------------------------------------------
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input
    unescaped = html.unescape(input.strip())
    # Strip ASCII control characters and the C1 control block.
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
# Utils types -----------------------------------------------------------------------
@dataclass
class EmbeddingFunc:
    """Callable wrapper bundling an async embedding function with its limits."""

    # Dimensionality of the vectors produced by func.
    embedding_dim: int
    # Maximum number of tokens func accepts per input text.
    max_token_size: int
    # The wrapped async embedding function.
    func: callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)
# Decorators ------------------------------------------------------------------------
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Add restriction of maximum async calling times for a async func.

    Args:
        max_size: maximum number of in-flight calls allowed at once.
        waitting_time: polling interval (seconds) while waiting for a slot.
    """

    def final_decro(func):
        """Not using async.Semaphore to avoid use nest-asyncio"""
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            # Busy-wait until a slot frees up.
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # BUGFIX: always release the slot, even when func raises.
                # Previously an exception leaked the counter, so repeated
                # failures could permanently exhaust the slots and deadlock
                # all future calls.
                __current_size -= 1

        return wait_func

    return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator factory that packages a raw embedding function into an EmbeddingFunc."""

    def final_decro(func) -> EmbeddingFunc:
        return EmbeddingFunc(**kwargs, func=func)

    return final_decro

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar, List
import numpy as np
from ._utils import EmbeddingFunc
@dataclass
class QueryParam:
    """Tunable parameters controlling a single GraphRAG query."""

    mode: Literal["local", "global", "naive"] = "global"
    # When True, return the assembled context instead of an LLM answer.
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    # Community hierarchy level to query at.
    level: int = 2
    top_k: int = 20

    # naive search
    # BUGFIX: this attribute was missing its ": int" annotation, so the
    # dataclass machinery treated it as a plain class attribute instead of a
    # configurable field (QueryParam(naive_max_token_for_text_unit=...) failed).
    naive_max_token_for_text_unit: int = 12000

    # local search
    local_max_token_for_text_unit: int = 4000  # 12000 * 0.33
    local_max_token_for_local_context: int = 4800  # 12000 * 0.4
    local_max_token_for_community_report: int = 3200  # 12000 * 0.27
    local_community_single_one: bool = False

    # global search
    global_min_community_rating: float = 0
    # NOTE(review): semantically a count; typed float in the original — kept for compatibility.
    global_max_consider_community: float = 512
    global_max_token_for_community_report: int = 16384
    global_special_community_map_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )
# Schema of a single text chunk stored in the KV store.
TextChunkSchema = TypedDict(
    "TextChunkSchema",
    {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)

# Schema of one graph community as produced by community_schema().
# BUGFIX: "edges" was annotated list[list[str, str]], which is not a valid
# generic form (list[] takes a single type argument). Each edge is a 2-item
# [src, tgt] list of node names, i.e. list[list[str]].
SingleCommunitySchema = TypedDict(
    "SingleCommunitySchema",
    {
        "level": int,
        "title": str,
        "edges": list[list[str]],
        "nodes": list[str],
        "chunk_ids": list[str],
        "occurrence": float,
        "sub_communities": list[str],
    },
)
class CommunitySchema(SingleCommunitySchema):
    # Community schema enriched with the generated community report.
    report_string: str
    report_json: dict


# Generic value type parameter for BaseKVStorage.
T = TypeVar("T")
@dataclass
class StorageNameSpace:
    """Base for all storage backends: a namespace plus the shared global config."""

    namespace: str
    global_config: dict

    async def index_start_callback(self):
        """Hook invoked before indexing starts (e.g. begin transactions)."""
        pass

    async def index_done_callback(self):
        """commit the storage operations after indexing"""
        pass

    async def query_done_callback(self):
        """commit the storage operations after querying"""
        pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Abstract vector-database interface."""

    # Async function used to embed 'content' strings into vectors.
    embedding_func: EmbeddingFunc
    # Extra per-item fields to persist alongside each vector.
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use 'content' field from value for embedding, use key as id.
        If embedding_func is None, use 'embedding' field from value
        """
        raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Abstract key-value storage interface over value type T."""

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        # When fields is given, implementations project each value to those keys.
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """return un-exist keys"""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError

    async def drop(self):
        # Remove all data in this namespace.
        raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Abstract graph storage interface (nodes, edges, clustering, embedding)."""

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    # Annotation fixed: implementations return per-node degrees (ints), not strings.
    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, Union[dict, None]]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        raise NotImplementedError

    async def clustering(self, algorithm: str):
        raise NotImplementedError

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Return the community representation with report and nodes"""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in nano-graphrag.")

View File

@@ -0,0 +1,171 @@
from typing import Union
import pickle
import asyncio
from openai import BadRequestError
from collections import defaultdict
import dspy
from nano_graphrag.base import (
BaseGraphStorage,
BaseVectorStorage,
TextChunkSchema,
)
from nano_graphrag.prompt import PROMPTS
from nano_graphrag._utils import logger, compute_mdhash_id
from nano_graphrag.entity_extraction.module import TypedEntityRelationshipExtractor
from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert
async def generate_dataset(
    chunks: dict[str, TextChunkSchema],
    filepath: str,
    save_dataset: bool = True,
    global_config: dict = {},  # NOTE(review): mutable default; only read here, never mutated
) -> list[dspy.Example]:
    """Extract entities/relationships from each chunk and build a dspy dataset.

    Args:
        chunks: mapping of chunk key -> chunk data; only "content" is used.
        filepath: pickle path used when save_dataset is True.
        save_dataset: persist the filtered examples to filepath.
        global_config: may point at a compiled dspy module to load.

    Returns:
        Examples that contain at least one entity AND one relationship.
    """
    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(
        chunk_key_dp: tuple[str, TextChunkSchema]
    ) -> dspy.Example:
        """Extract from one chunk and wrap the result as a dspy Example."""
        nonlocal already_processed, already_entities, already_relations
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The extractor is synchronous; run it in a worker thread so the
            # event loop can process other chunks concurrently.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            # Treat an LLM rejection as an empty extraction rather than failing the run.
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []

        example = dspy.Example(
            input_text=content, entities=entities, relationships=relationships
        ).with_inputs("input_text")

        already_entities += len(entities)
        already_relations += len(relationships)
        already_processed += 1
        # Lightweight progress spinner on stdout.
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return example

    examples = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    # Keep only examples that actually produced both entities and relationships.
    filtered_examples = [
        example
        for example in examples
        if len(example.entities) > 0 and len(example.relationships) > 0
    ]
    num_filtered_examples = len(examples) - len(filtered_examples)
    if save_dataset:
        with open(filepath, "wb") as f:
            pickle.dump(filtered_examples, f)
        if filtered_examples:
            logger.info(
                f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples"
            )
        else:
            # BUGFIX: the previous log line indexed filtered_examples[0] and
            # raised IndexError whenever every example was filtered out.
            logger.info(
                f"Saved 0 examples, filtered {num_filtered_examples} examples"
            )

    return filtered_examples
async def extract_entities_dspy(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,  # NOTE(review): "knwoledge" typo kept — renaming would break keyword callers
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    """Extract entities/relationships from chunks and merge them into the graph.

    Each chunk is processed concurrently; extracted nodes and edges are merged
    (deduplicated) into knwoledge_graph_inst and entity names are indexed into
    entity_vdb. Returns the graph instance, or None when nothing was extracted.
    """
    entity_extractor = TypedEntityRelationshipExtractor(num_refine_turns=1, self_refine=True)

    if global_config.get("use_compiled_dspy_entity_relationship", False):
        entity_extractor.load(global_config["entity_relationship_module_path"])

    ordered_chunks = list(chunks.items())
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Extract from one chunk; returns (nodes-by-name, edges-by-(src,tgt))."""
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        try:
            # The extractor is synchronous; run it in a worker thread so the
            # event loop can process other chunks concurrently.
            prediction = await asyncio.to_thread(entity_extractor, input_text=content)
            entities, relationships = prediction.entities, prediction.relationships
        except BadRequestError as e:
            # Treat an LLM rejection as an empty extraction rather than failing the run.
            logger.error(f"Error in TypedEntityRelationshipExtractor: {e}")
            entities, relationships = [], []

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)

        # Tag every extracted item with its originating chunk.
        for entity in entities:
            entity["source_id"] = chunk_key
            maybe_nodes[entity["entity_name"]].append(entity)
            already_entities += 1

        for relationship in relationships:
            relationship["source_id"] = chunk_key
            maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(
                relationship
            )
            already_relations += 1

        already_processed += 1
        # Lightweight progress spinner on stdout.
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()

    # Collect per-chunk extractions into global name/edge buckets for merging.
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[k].extend(v)

    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )

    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )

    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None

    if entity_vdb is not None:
        # Index each entity under a content-addressed id ("ent-" + md5).
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    return knwoledge_graph_inst

View File

@@ -0,0 +1,62 @@
import dspy
from nano_graphrag.entity_extraction.module import Relationship
# NOTE: for dspy.Signature subclasses the class docstring IS the LLM prompt —
# editing its wording changes model behavior, so it is preserved verbatim.
class AssessRelationships(dspy.Signature):
    """
    Assess the similarity between gold and predicted relationships:
    1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
    2. For matched pairs, compare:
       a) Description similarity (semantic meaning)
       b) Weight similarity
       c) Order similarity
    3. Consider unmatched relationships as penalties.
    4. Aggregate scores, accounting for precision and recall.
    5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).
    Key considerations:
    - Prioritize matching based on entity pairs over exact string matches.
    - Use semantic similarity for descriptions rather than exact matches.
    - Weight the importance of different aspects (e.g., entity matching, description, weight, order).
    - Balance the impact of matched and unmatched relationships in the final score.
    """

    gold_relationships: list[Relationship] = dspy.InputField(
        desc="The gold-standard relationships to compare against."
    )
    predicted_relationships: list[Relationship] = dspy.InputField(
        desc="The predicted relationships to compare against the gold-standard relationships."
    )
    similarity_score: float = dspy.OutputField(
        desc="Similarity score between 0 and 1, with 1 being the highest similarity."
    )
def relationships_similarity_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """LLM-judged similarity (0..1) between gold and predicted relationships.

    NOTE(review): constructs a fresh ChainOfThought module and issues an LLM
    call on every invocation — non-deterministic and costly; evaluation only.
    """
    model = dspy.ChainOfThought(AssessRelationships)
    # Re-validate raw dicts through the Relationship model before judging.
    gold_relationships = [Relationship(**item) for item in gold["relationships"]]
    predicted_relationships = [Relationship(**item) for item in pred["relationships"]]
    similarity_score = float(
        model(
            gold_relationships=gold_relationships,
            predicted_relationships=predicted_relationships,
        ).similarity_score
    )
    return similarity_score
def entity_recall_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Recall of predicted entity names against the gold-standard entity names."""
    gold_names = {item["entity_name"] for item in gold["entities"]}
    pred_names = {item["entity_name"] for item in pred["entities"]}
    hits = len(gold_names & pred_names)
    missed = len(gold_names - pred_names)
    denominator = hits + missed
    # Empty gold set yields 0 rather than dividing by zero.
    return hits / denominator if denominator > 0 else 0

View File

@@ -0,0 +1,330 @@
import dspy
from pydantic import BaseModel, Field
from nano_graphrag._utils import clean_str
from nano_graphrag._utils import logger
"""
Obtained from:
https://github.com/SciPhi-AI/R2R/blob/6e958d1e451c1cb10b6fc868572659785d1091cb/r2r/providers/prompts/defaults.jsonl
"""
# Default closed set of entity labels for the typed extractor below; callers
# may pass their own list, but the extraction signature instructs the model to
# stay within whatever set it receives.
ENTITY_TYPES = [
    "PERSON",
    "ORGANIZATION",
    "LOCATION",
    "DATE",
    "TIME",
    "MONEY",
    "PERCENTAGE",
    "PRODUCT",
    "EVENT",
    "LANGUAGE",
    "NATIONALITY",
    "RELIGION",
    "TITLE",
    "PROFESSION",
    "ANIMAL",
    "PLANT",
    "DISEASE",
    "MEDICATION",
    "CHEMICAL",
    "MATERIAL",
    "COLOR",
    "SHAPE",
    "MEASUREMENT",
    "WEATHER",
    "NATURAL_DISASTER",
    "AWARD",
    "LAW",
    "CRIME",
    "TECHNOLOGY",
    "SOFTWARE",
    "HARDWARE",
    "VEHICLE",
    "FOOD",
    "DRINK",
    "SPORT",
    "MUSIC_GENRE",
    "INSTRUMENT",
    "ARTWORK",
    "BOOK",
    "MOVIE",
    "TV_SHOW",
    "ACADEMIC_SUBJECT",
    "SCIENTIFIC_THEORY",
    "POLITICAL_PARTY",
    "CURRENCY",
    "STOCK_SYMBOL",
    "FILE_TYPE",
    "PROGRAMMING_LANGUAGE",
    "MEDICAL_PROCEDURE",
    "CELESTIAL_BODY",
]
class Entity(BaseModel):
    """One extracted entity; pydantic validates the score into [0, 1].

    Note: the Field ``description`` strings are surfaced to the LLM by dspy's
    typed predictors, so they are part of the prompt, not mere documentation.
    """

    entity_name: str = Field(..., description="The name of the entity.")
    entity_type: str = Field(..., description="The type of the entity.")
    description: str = Field(
        ..., description="The description of the entity, in details and comprehensive."
    )
    importance_score: float = Field(
        ...,
        ge=0,
        le=1,
        description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.",
    )

    def to_dict(self) -> dict:
        """Return a plain dict with name/type upper-cased and all text cleaned."""
        return {
            "entity_name": clean_str(self.entity_name.upper()),
            "entity_type": clean_str(self.entity_type.upper()),
            "description": clean_str(self.description),
            "importance_score": float(self.importance_score),
        }
class Relationship(BaseModel):
    """A directed relationship between two extracted entities.

    ``weight`` is validated into [0, 1]; ``order`` into {1, 2, 3} (direct,
    second-order, third-order). Field descriptions feed the LLM prompt.
    """

    src_id: str = Field(..., description="The name of the source entity.")
    tgt_id: str = Field(..., description="The name of the target entity.")
    description: str = Field(
        ...,
        description="The description of the relationship between the source and target entity, in details and comprehensive.",
    )
    weight: float = Field(
        ...,
        ge=0,
        le=1,
        description="The weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.",
    )
    order: int = Field(
        ...,
        ge=1,
        le=3,
        description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.",
    )

    def to_dict(self) -> dict:
        """Return a plain dict with endpoint names upper-cased and text cleaned."""
        return {
            "src_id": clean_str(self.src_id.upper()),
            "tgt_id": clean_str(self.tgt_id.upper()),
            "description": clean_str(self.description),
            "weight": float(self.weight),
            "order": int(self.order),
        }
# NOTE: for dspy.Signature subclasses, the class docstring and every Field
# ``desc`` are compiled into the LLM instruction prompt — treat them as
# runtime behavior, not documentation.
class CombinedExtraction(dspy.Signature):
    """
    Given a text document that is potentially relevant to this activity and a list of entity types,
    identify all entities of those types from the text and all relationships among the identified entities.
    Entity Guidelines:
    1. Each entity name should be an actual atomic word from the input text.
    2. Avoid duplicates and generic terms.
    3. Make sure descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
    a). The entity's role or significance in the context
    b). Key attributes or characteristics
    c). Relationships to other entities (if applicable)
    d). Historical or cultural relevance (if applicable)
    e). Any notable actions or events associated with the entity
    4. All entity types from the text must be included.
    5. IMPORTANT: Only use entity types from the provided 'entity_types' list. Do not introduce new entity types.
    Relationship Guidelines:
    1. Make sure relationship descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
    a). The nature of the relationship (e.g., familial, professional, causal)
    b). The impact or significance of the relationship on both entities
    c). Any historical or contextual information relevant to the relationship
    d). How the relationship evolved over time (if applicable)
    e). Any notable events or actions that resulted from this relationship
    2. Include direct relationships (order 1) as well as higher-order relationships (order 2 and 3):
    a). Direct relationships: Immediate connections between entities.
    b). Second-order relationships: Indirect effects or connections that result from direct relationships.
    c). Third-order relationships: Further indirect effects that result from second-order relationships.
    3. The "src_id" and "tgt_id" fields must exactly match entity names from the extracted entities list.
    """

    # Source document to mine.
    input_text: str = dspy.InputField(
        desc="The text to extract entities and relationships from."
    )
    # Closed label set the model must stay within (guideline 5 above).
    entity_types: list[str] = dspy.InputField(
        desc="List of entity types used for extraction."
    )
    entities: list[Entity] = dspy.OutputField(
        desc="List of entities extracted from the text and the entity types."
    )
    relationships: list[Relationship] = dspy.OutputField(
        desc="List of relationships extracted from the text and the entity types."
    )
# Critique step of the optional self-refine loop (see
# TypedEntityRelationshipExtractor). Docstring/desc strings are prompt text.
class CritiqueCombinedExtraction(dspy.Signature):
    """
    Critique the current extraction of entities and relationships from a given text.
    Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.
    Critique Guidelines:
    1. Evaluate if all relevant entities from the input text are captured and correctly typed.
    2. Check if entity descriptions are comprehensive and follow the provided guidelines.
    3. Assess the completeness of relationship extractions, including higher-order relationships.
    4. Verify that relationship descriptions are detailed and follow the provided guidelines.
    5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
    6. Suggest specific improvements or additions to enhance the quality of the extraction.
    """

    input_text: str = dspy.InputField(
        desc="The original text from which entities and relationships were extracted."
    )
    entity_types: list[str] = dspy.InputField(
        desc="List of valid entity types for this extraction task."
    )
    current_entities: list[Entity] = dspy.InputField(
        desc="List of currently extracted entities to be critiqued."
    )
    current_relationships: list[Relationship] = dspy.InputField(
        desc="List of currently extracted relationships to be critiqued."
    )
    # Free-text critiques consumed by RefineCombinedExtraction.
    entity_critique: str = dspy.OutputField(
        desc="Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy.."
    )
    relationship_critique: str = dspy.OutputField(
        desc="Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy.."
    )
# Refine step of the optional self-refine loop: takes the previous extraction
# plus the critiques and emits an improved extraction of the same shape.
class RefineCombinedExtraction(dspy.Signature):
    """
    Refine the current extraction of entities and relationships based on the provided critique.
    Improve completeness, accuracy, and adherence to the extraction guidelines.
    Refinement Guidelines:
    1. Address all points raised in the entity and relationship critiques.
    2. Add missing entities and relationships identified in the critique.
    3. Improve entity and relationship descriptions as suggested.
    4. Ensure all refinements still adhere to the original extraction guidelines.
    5. Maintain consistency between entities and relationships during refinement.
    6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
    """

    input_text: str = dspy.InputField(
        desc="The original text from which entities and relationships were extracted."
    )
    entity_types: list[str] = dspy.InputField(
        desc="List of valid entity types for this extraction task."
    )
    current_entities: list[Entity] = dspy.InputField(
        desc="List of currently extracted entities to be refined."
    )
    current_relationships: list[Relationship] = dspy.InputField(
        desc="List of currently extracted relationships to be refined."
    )
    entity_critique: str = dspy.InputField(
        desc="Detailed critique of the current entities to guide refinement."
    )
    relationship_critique: str = dspy.InputField(
        desc="Detailed critique of the current relationships to guide refinement."
    )
    refined_entities: list[Entity] = dspy.OutputField(
        desc="List of refined entities, addressing the entity critique and improving upon the current entities."
    )
    refined_relationships: list[Relationship] = dspy.OutputField(
        desc="List of refined relationships, addressing the relationship critique and improving upon the current relationships."
    )
class TypedEntityRelationshipExtractorException(dspy.Module):
    """Wrap a dspy predictor so that selected exception types degrade to an
    empty prediction instead of aborting the extraction pipeline.

    Any exception NOT listed in ``exception_types`` is re-raised unchanged.
    """

    def __init__(
        self,
        predictor: dspy.Module,
        exception_types: tuple[type[Exception], ...] = (Exception,),
    ):
        super().__init__()
        self.predictor = predictor
        self.exception_types = exception_types

    def copy(self):
        """Return a new wrapper around the same predictor.

        Bug fix: previously ``exception_types`` was not forwarded, so a copy
        silently reverted to swallowing every ``Exception``.
        """
        return TypedEntityRelationshipExtractorException(
            self.predictor, self.exception_types
        )

    def forward(self, **kwargs):
        """Run the wrapped predictor, mapping configured failures to an empty
        ``dspy.Prediction(entities=[], relationships=[])``."""
        try:
            return self.predictor(**kwargs)
        except Exception as e:
            if isinstance(e, self.exception_types):
                # Configured failure mode: keep the pipeline alive on a bad chunk.
                return dspy.Prediction(entities=[], relationships=[])
            # Bare `raise` preserves the original traceback.
            raise
class TypedEntityRelationshipExtractor(dspy.Module):
    """Extract typed entities and relationships from text, optionally running a
    critique-and-refine loop to improve the first-pass extraction.

    Returns dict-shaped records (see Entity.to_dict / Relationship.to_dict)
    in a ``dspy.Prediction(entities=..., relationships=...)``.
    """

    def __init__(
        self,
        lm: dspy.LM = None,
        max_retries: int = 3,
        # NOTE(review): module-level list shared across instances as a default;
        # safe only while nothing mutates it — confirm callers never do.
        entity_types: list[str] = ENTITY_TYPES,
        self_refine: bool = False,
        num_refine_turns: int = 1,
    ):
        super().__init__()
        self.lm = lm
        self.entity_types = entity_types
        self.self_refine = self_refine
        self.num_refine_turns = num_refine_turns
        self.extractor = dspy.ChainOfThought(
            signature=CombinedExtraction, max_retries=max_retries
        )
        # Wrap the extractor so ValueError (e.g. typed-output parsing failures)
        # yields an empty prediction instead of crashing the pipeline.
        self.extractor = TypedEntityRelationshipExtractorException(
            self.extractor, exception_types=(ValueError,)
        )
        if self.self_refine:
            # Critique/refine predictors are only built when the loop is enabled.
            self.critique = dspy.ChainOfThought(
                signature=CritiqueCombinedExtraction, max_retries=max_retries
            )
            self.refine = dspy.ChainOfThought(
                signature=RefineCombinedExtraction, max_retries=max_retries
            )

    def forward(self, input_text: str) -> dspy.Prediction:
        """Run extraction (and the optional refine loop) on ``input_text``."""
        # Use the instance LM if one was supplied, else dspy's global setting.
        with dspy.context(lm=self.lm if self.lm is not None else dspy.settings.lm):
            extraction_result = self.extractor(
                input_text=input_text, entity_types=self.entity_types
            )
            current_entities: list[Entity] = extraction_result.entities
            current_relationships: list[Relationship] = extraction_result.relationships
            if self.self_refine:
                # Each turn critiques the current extraction, then replaces it
                # wholesale with the refined output.
                for _ in range(self.num_refine_turns):
                    critique_result = self.critique(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                    )
                    refined_result = self.refine(
                        input_text=input_text,
                        entity_types=self.entity_types,
                        current_entities=current_entities,
                        current_relationships=current_relationships,
                        entity_critique=critique_result.entity_critique,
                        relationship_critique=critique_result.relationship_critique,
                    )
                    logger.debug(
                        f"entities: {len(current_entities)} | refined_entities: {len(refined_result.refined_entities)}"
                    )
                    logger.debug(
                        f"relationships: {len(current_relationships)} | refined_relationships: {len(refined_result.refined_relationships)}"
                    )
                    current_entities = refined_result.refined_entities
                    current_relationships = refined_result.refined_relationships
        # Normalize pydantic models into plain dicts for downstream storage.
        entities = [entity.to_dict() for entity in current_entities]
        relationships = [
            relationship.to_dict() for relationship in current_relationships
        ]
        return dspy.Prediction(entities=entities, relationships=relationships)

View File

@@ -0,0 +1,382 @@
import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast
from ._llm import (
amazon_bedrock_embedding,
create_amazon_bedrock_complete_function,
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete,
)
from ._op import (
chunking_by_token_size,
extract_entities,
generate_community_report,
get_chunks,
local_query,
global_query,
naive_query,
)
from ._storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
)
from ._utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
always_get_an_event_loop,
logger,
TokenizerWrapper,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
)
@dataclass
class GraphRAG:
    """End-to-end GraphRAG orchestrator.

    ``insert()`` chunks documents, extracts an entity/relation graph, clusters
    it and generates community reports; ``query()`` answers questions in
    "local", "global" or "naive" mode. All configuration lives in dataclass
    fields, and ``asdict(self)`` is handed to worker functions as
    ``global_config`` — field names are therefore part of the interface.
    """

    # NOTE(review): ':' in this timestamp makes the default path invalid on
    # Windows filesystems — confirm the target platforms.
    working_dir: str = field(
        default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # graph mode
    enable_local: bool = True
    enable_naive_rag: bool = False
    # text chunking
    tokenizer_type: str = "tiktoken"  # or 'huggingface'
    tiktoken_model_name: str = "gpt-4o"
    huggingface_model_name: str = "bert-base-uncased"  # default HF model
    chunk_func: Callable[
        [
            list[list[int]],
            List[str],
            TokenizerWrapper,
            Optional[int],
            Optional[int],
        ],
        List[Dict[str, Union[str, int]]],
    ] = chunking_by_token_size
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500
    # graph clustering
    graph_cluster_algorithm: str = "leiden"
    max_graph_cluster_size: int = 10
    graph_cluster_seed: int = 0xDEADBEEF
    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            # fix: "num_walks" appeared twice in the original literal; the
            # duplicate (same value) has been removed.
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )
    # community reports
    special_community_report_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )
    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16
    query_better_than_threshold: float = 0.2
    # LLM
    using_azure_openai: bool = False
    using_amazon_bedrock: bool = False
    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
    best_model_func: callable = gpt_4o_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
    cheap_model_func: callable = gpt_4o_mini_complete
    cheap_model_max_token_size: int = 32768
    cheap_model_max_async: int = 16
    # entity extraction
    entity_extraction_func: callable = extract_entities
    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True
    # extension
    always_create_working_dir: bool = True
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        """Build tokenizer, swap in provider-specific LLM funcs, create the
        working dir, and wire up all storage instances and rate limiters."""
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"GraphRAG init with param:\n\n  {_print_config}\n")
        self.tokenizer_wrapper = TokenizerWrapper(
            tokenizer_type=self.tokenizer_type,
            model_name=self.tiktoken_model_name if self.tokenizer_type == "tiktoken" else self.huggingface_model_name
        )
        if self.using_azure_openai:
            # Only replace funcs the user left at their OpenAI defaults.
            if self.best_model_func == gpt_4o_complete:
                self.best_model_func = azure_gpt_4o_complete
            if self.cheap_model_func == gpt_4o_mini_complete:
                self.cheap_model_func = azure_gpt_4o_mini_complete
            if self.embedding_func == openai_embedding:
                self.embedding_func = azure_openai_embedding
            logger.info(
                "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
            )
        if self.using_amazon_bedrock:
            # Bedrock replaces the LLM/embedding funcs unconditionally.
            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
            self.embedding_func = amazon_bedrock_embedding
            logger.info(
                "Switched the default openai funcs to Amazon Bedrock"
            )
        if not os.path.exists(self.working_dir) and self.always_create_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)
        # KV storages: raw docs, chunks, optional LLM cache, community reports.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.community_reports = self.key_string_value_json_storage_cls(
            namespace="community_reports", global_config=asdict(self)
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )
        # Cap concurrent embedding calls before handing the func to the VDBs.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = (
            self.vector_db_storage_cls(
                namespace="entities",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
                meta_fields={"entity_name"},
            )
            if self.enable_local
            else None
        )
        self.chunks_vdb = (
            self.vector_db_storage_cls(
                namespace="chunks",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
            if self.enable_naive_rag
            else None
        )
        # Bind the response cache and cap concurrency for both LLM funcs.
        self.best_model_func = limit_async_func_call(self.best_model_max_async)(
            partial(self.best_model_func, hashing_kv=self.llm_response_cache)
        )
        self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
            partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        """Synchronous wrapper around :meth:`ainsert`."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    def query(self, query: str, param: QueryParam = QueryParam()):
        """Synchronous wrapper around :meth:`aquery`."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        """Answer ``query`` in the mode selected by ``param.mode``.

        Raises ValueError when the requested mode is disabled or unknown.
        """
        if param.mode == "local" and not self.enable_local:
            raise ValueError("enable_local is False, cannot query in local mode")
        if param.mode == "naive" and not self.enable_naive_rag:
            raise ValueError("enable_naive_rag is False, cannot query in naive mode")
        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def ainsert(self, string_or_strings):
        """Async insertion pipeline: dedup docs -> chunk -> extract entities ->
        cluster graph -> community reports -> commit storages."""
        await self._insert_start()
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]
            # ---------- new docs (content-hash keyed, so re-inserts dedup)
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
            # ---------- chunking
            inserting_chunks = get_chunks(
                new_docs=new_docs,
                chunk_func=self.chunk_func,
                overlap_token_size=self.chunk_overlap_token_size,
                max_token_size=self.chunk_token_size,
                tokenizer_wrapper=self.tokenizer_wrapper,
            )
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
            if self.enable_naive_rag:
                logger.info("Insert chunks for naive RAG")
                await self.chunks_vdb.upsert(inserting_chunks)
            # TODO: don't support incremental update for communities now, so we have to drop all
            await self.community_reports.drop()
            # ---------- extract/summary entity and upsert to graph
            logger.info("[Entity Extraction]...")
            # NOTE: 'knwoledge_graph_inst' (sic) must match the keyword accepted
            # by extract_entities — do not rename only on this side.
            maybe_new_kg = await self.entity_extraction_func(
                inserting_chunks,
                knwoledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                tokenizer_wrapper=self.tokenizer_wrapper,
                global_config=asdict(self),
                using_amazon_bedrock=self.using_amazon_bedrock,
            )
            if maybe_new_kg is None:
                logger.warning("No new entities found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg
            # ---------- update clusterings of graph
            logger.info("[Community Report]...")
            await self.chunk_entity_relation_graph.clustering(
                self.graph_cluster_algorithm
            )
            await generate_community_report(
                self.community_reports, self.chunk_entity_relation_graph, self.tokenizer_wrapper, asdict(self)
            )
            # ---------- commit upsertings and indexing
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            # Always flush/commit storages, even on early return or failure.
            await self._insert_done()

    async def _insert_start(self):
        """Notify storages that an indexing pass is starting."""
        tasks = []
        for storage_inst in [
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_start_callback())
        await asyncio.gather(*tasks)

    async def _insert_done(self):
        """Fan out index-done callbacks to every configured storage."""
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.community_reports,
            self.entities_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    async def _query_done(self):
        """Flush the LLM response cache after a query (if caching is enabled)."""
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

View File

@@ -0,0 +1,305 @@
"""
GraphRAG core prompts (Chinese, aerospace-oriented).
"""
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
# Claim-extraction prompt (Chinese): pulls subject/object claims with type,
# status, dates and source quotes. Rendered with {entity_specs},
# {claim_description}, {input_text} and the delimiter tokens.
PROMPTS[
    "claim_extraction"
] = """-任务定位-
你是航天知识情报分析助手,负责从文本中抽取实体相关的主张/断言claim
-目标-
给定输入文本、实体约束和主张说明,抽取满足条件的实体及其对应主张,结果必须可溯源。
-执行步骤-
1. 先识别满足实体约束的命名实体。实体约束可能是实体名称列表,也可能是实体类型列表。
2. 对步骤1中的每个实体抽取其作为主语的主张。每条主张需输出
- Subject: 主张主体实体名大写必须来自步骤1
- Object: 客体实体名(大写;未知时使用 **NONE**
- Claim Type: 主张类型(大写;应可复用)
- Claim Status: **TRUE**、**FALSE** 或 **SUSPECTED**
- Claim Description: 说明主张的依据、逻辑和关键证据
- Claim Date: 起止时间ISO-8601未知时为 **NONE**
- Claim Source Text: 与主张直接相关的原文引文(尽量完整)
格式要求:
(<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
3. 使用 **{record_delimiter}** 连接所有记录。
4. 结束时输出 {completion_delimiter}
-约束-
- 仅基于输入文本,不得编造。
- 航天语境优先(如航天器、推进系统、姿态控制、任务阶段、地面系统、试验事件、故障模式等)。
-输入-
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output: """
# Community-report prompt (Chinese): asks for a structured JSON assessment of
# one graph community. Rendered with {input_text}.
# NOTE(review): the literal JSON braces in the template below are NOT escaped
# as {{ }} — if this prompt is rendered via str.format it will raise; verify
# how the pipeline substitutes {input_text}.
PROMPTS[
    "community_report"
] = """你是航天领域知识图谱分析助手负责为一个社区community生成结构化研判报告。
# 目标
根据给定的实体、关系和可选主张,输出可用于技术评审与任务决策的社区报告。
# 报告结构
必须返回 JSON 字符串,结构如下:
{
"title": <标题>,
"summary": <执行摘要>,
"rating": <0-10 浮点评分>,
"rating_explanation": <评分说明>,
"findings": [
{
"summary": <要点小结>,
"explanation": <详细说明>
}
]
}
# 字段要求
- title: 简洁且具体,尽量包含代表性实体。
- summary: 说明社区整体结构、核心实体、关键关系与主要风险/价值。
- rating: 社区影响度/风险度评分0-10
- rating_explanation: 单句说明评分依据。
- findings: 5-10 条关键发现,覆盖技术链路、任务影响、可靠性风险、协同关系等。
# 领域要求(航天)
优先关注以下维度:
- 任务阶段:论证、研制、总装、测试、发射、入轨、在轨运行、回收/退役
- 系统层级:航天器、有效载荷、推进系统、姿态控制、测控通信、地面系统
- 关键指标:推力、比冲、功率、带宽、精度、寿命、可靠性、故障率
- 风险与依赖:单点故障、接口依赖、时序耦合、供应链与试验验证缺口
# 证据约束
- 仅使用输入中可证据化信息。
- 无证据内容不得写入。
- 若信息不足,应明确指出不确定性与缺失点。
# 输入
Text:
```
{input_text}
```
Output:
"""
# Entity/relationship extraction prompt (Chinese). Rendered with
# {entity_types}, {input_text} and the tuple/record/completion delimiters.
PROMPTS[
    "entity_extraction"
] = """-任务目标-
给定文本与实体类型列表,识别所有相关实体,并抽取实体间“明确存在”的关系。
-实体类型约束-
- entity_type 必须来自给定集合:[{entity_types}]
- 若无法确定,使用最接近类型,不可臆造新类型
-关系类型约束-
关系描述必须以“关系类型=<类型>;依据=<说明>”开头。
关系类型从以下集合中选择:
[组成, 隶属, 控制, 被控制, 供能, 支撑, 测量, 感知, 执行, 通信, 影响, 制约, 因果, 时序前后, 协同, 风险关联, 其他]
-执行步骤-
1. 抽取实体。每个实体输出:
- entity_name: 实体名(保留原文专有名词;必要时标准化)
- entity_type: 实体类型(必须在给定集合中)
- entity_description: 面向航天任务语境的实体描述(属性、职责、行为)
实体格式:
("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. 在步骤1实体中抽取“证据充分且语义明确”的关系对(source_entity, target_entity)。
每条关系输出:
- source_entity
- target_entity
- relationship_description: 必须以“关系类型=<类型>;依据=<说明>”开头
- relationship_strength: 1-10 数值,表示关系强度
关系格式:
("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3. 结果使用 **{record_delimiter}** 拼接。
4. 结束时输出 {completion_delimiter}
-质量规则-
- 只抽取文本中可直接支持的实体和关系。
- 不输出模糊、猜测或无依据关系。
- 对航天语义优先:航天器、分系统、任务阶段、参数指标、故障模式、试验事件。
-输入-
Entity_types: {entity_types}
Text: {input_text}
Output:
"""
# Description-merging prompt (Chinese): condenses multiple description
# fragments for one or two entities into a single consistent summary.
# Rendered with {entity_name} and {description_list}.
PROMPTS[
    "summarize_entity_descriptions"
] = """你是航天知识库整理助手。
给定一个或两个实体名称,以及若干描述片段,请合并为一段一致、完整、可复用的摘要。
要求:
- 覆盖所有有效信息
- 若描述冲突,给出最一致、最保守的综合结论
- 使用第三人称
- 保留实体名,避免指代不清
- 优先保留任务阶段、技术指标、系统依赖和风险信息
#######
-输入数据-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
# Gleaning prompts used for multi-pass extraction. NOTE: "entiti" (sic) is a
# load-bearing dict key looked up elsewhere in the pipeline — do not rename.
PROMPTS[
    "entiti_continue_extraction"
] = """上一轮可能遗漏了实体或关系。请仅补充遗漏项,严格沿用既定输出格式,不要重复已输出记录:"""
PROMPTS[
    "entiti_if_loop_extraction"
] = """请判断是否仍有遗漏实体或关系。仅回答 YES 或 NO。"""
PROMPTS["DEFAULT_ENTITY_TYPES"] = [
"航天器",
"任务",
"任务阶段",
"有效载荷",
"推进系统",
"姿态控制系统",
"测控通信系统",
"电源系统",
"热控系统",
"结构机构",
"传感器",
"执行机构",
"地面系统",
"组织机构",
"试验事件",
"故障模式",
"参数指标",
"轨道",
"地点",
"时间",
]
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
# Local-mode answer prompt (Chinese): answer from a local subgraph context.
# Rendered with {response_type} and {context_data}.
PROMPTS[
    "local_rag_response"
] = """---角色---
你是航天领域图谱问答助手,擅长基于局部子图进行多跳推理。
---任务---
根据输入的数据表(实体、关系、证据片段)回答用户问题。
---回答要求---
1. 先给结论,再给推理链路。
2. 推理链路至少包含:
- 关键实体
- 关键关系
- 中间推断
- 最终结论
3. 若证据不足,明确说明“证据不足以得出结论”。
4. 严禁编造。
5. 使用中文、专业且简洁。
---目标长度与格式---
{response_type}
---输入数据表---
{context_data}
---输出格式建议---
- 结论
- 推理链路
- 证据与不确定性
"""
# Global-mode "map" prompt (Chinese): distill community-level info into scored
# JSON points. Rendered with {context_data}.
# NOTE(review): the literal JSON braces below are NOT escaped as {{ }} — if
# this prompt is rendered via str.format it will raise; verify the renderer.
PROMPTS[
    "global_map_rag_points"
] = """---角色---
你是航天知识全局研判助手,负责从社区级信息中提炼关键观点。
---任务---
根据输入数据表,输出一组可用于后续全局汇总的关键点。
---输出要求---
- 必须输出 JSON
{
"points": [
{"description": "观点描述", "score": 0-100整数}
]
}
- description: 观点需可证据化,优先覆盖跨系统影响、任务阶段耦合、技术风险和性能趋势。
- score: 对回答用户问题的重要性分值。
- 若无法回答,输出 1 条 score=0 且明确说明信息不足。
- 仅使用输入证据,不得编造。
---输入数据表---
{context_data}
"""
# Global-mode "reduce" prompt (Chinese): merge ranked analyst reports into one
# decision-oriented answer. Rendered with {response_type} and {report_data}.
PROMPTS[
    "global_reduce_rag_response"
] = """---角色---
你是航天领域总师级分析助手,负责融合多位分析员报告并给出全局结论。
---任务---
根据按重要性降序排列的分析员报告,输出面向决策的综合回答。
---要求---
1. 先给总体结论,再给分项分析(趋势、共性风险、关键差异、建议动作)。
2. 报告融合时去重、去噪,保留高证据密度信息。
3. 必须指出不确定性与信息缺口。
4. 仅基于输入报告,不得编造。
5. 使用中文,风格专业、严谨。
---目标长度与格式---
{response_type}
---分析员报告---
{report_data}
"""
# Naive-RAG answer prompt (Chinese): single-turn answer over raw chunks.
# Rendered with {content_data} and {response_type}.
PROMPTS[
    "naive_rag_response"
] = """你是航天知识问答助手。
下面是可用知识:
{content_data}
---
请基于上述知识回答用户问题,要求:
- 回答准确、简洁、专业
- 若信息不足,明确说明缺失点
- 不得编造
---目标长度与格式---
{response_type}
"""
# Canned answer returned when no usable context is available.
PROMPTS["fail_response"] = "抱歉,当前无法基于现有信息回答该问题。"
# Spinner frames for console progress display.
PROMPTS["process_tickers"] = ["-", "\\", "|", "/"]
# Separators tried in order by the text splitter (coarse to fine).
# NOTE(review): several entries below render as empty strings and appear
# duplicated — they are likely full-width CJK punctuation (e.g. 。，！？) lost
# in an encoding step; verify against the original file before relying on them.
PROMPTS["default_text_separator"] = [
    "\n\n",
    "\r\n\r\n",
    "\n",
    "\r\n",
    "",
    "",
    ".",
    "",
    "!",
    "",
    "?",
    " ",
    "\t",
    "\u3000",
    "\u200b",
]