init. project
This commit is contained in:
9
rag-web-ui/backend/nano_graphrag/_storage/__init__.py
Normal file
9
rag-web-ui/backend/nano_graphrag/_storage/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .gdb_networkx import NetworkXStorage
|
||||
from .gdb_neo4j import Neo4jStorage
|
||||
from .vdb_nanovectordb import NanoVectorDBStorage
|
||||
from .kv_json import JsonKVStorage
|
||||
|
||||
try:
|
||||
from .vdb_hnswlib import HNSWVectorStorage
|
||||
except ImportError:
|
||||
HNSWVectorStorage = None
|
||||
529
rag-web-ui/backend/nano_graphrag/_storage/gdb_neo4j.py
Normal file
529
rag-web-ui/backend/nano_graphrag/_storage/gdb_neo4j.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import json
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from ..base import BaseGraphStorage, SingleCommunitySchema
|
||||
from .._utils import logger
|
||||
from ..prompt import GRAPH_FIELD_SEP
|
||||
|
||||
neo4j_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# Translation table: '.', '-', ':' -> '_'; '/' and '\' -> '__'.
_PATH_ID_TRANS = str.maketrans({".": "_", "/": "__", "-": "_", ":": "_", "\\": "__"})


def make_path_idable(path):
    """Sanitize a filesystem path into a string safe to embed in a Neo4j label.

    Replaces path separators and other punctuation that would be awkward in
    an identifier. Directory separators map to a double underscore so the
    original structure stays (mostly) recoverable by eye.
    """
    # One C-level pass via str.translate instead of five chained .replace()
    # calls; output is byte-identical to the original chain.
    return path.translate(_PATH_ID_TRANS)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4jStorage(BaseGraphStorage):
|
||||
def __post_init__(self):
    """Read connection settings from ``global_config`` and open the driver.

    Raises:
        ValueError: if ``neo4j_url`` or ``neo4j_auth`` is missing from
            ``addon_params``.
    """
    self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
    self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
    # Prefix the namespace with the (sanitized) working dir so several
    # working directories can share one Neo4j instance without label clashes.
    self.namespace = (
        f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
    )
    logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
    if self.neo4j_url is None or self.neo4j_auth is None:
        raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
    self.async_driver = AsyncGraphDatabase.driver(
        self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
    )
|
||||
|
||||
# async def create_database(self):
|
||||
# async with self.async_driver.session() as session:
|
||||
# try:
|
||||
# constraints = await session.run("SHOW CONSTRAINTS")
|
||||
# # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
|
||||
# # so have to check if the constrain exists
|
||||
# constrain_exists = False
|
||||
|
||||
# async for record in constraints:
|
||||
# if (
|
||||
# self.namespace in record["labelsOrTypes"]
|
||||
# and "id" in record["properties"]
|
||||
# and record["type"] == "UNIQUENESS"
|
||||
# ):
|
||||
# constrain_exists = True
|
||||
# break
|
||||
# if not constrain_exists:
|
||||
# await session.run(
|
||||
# f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
|
||||
# )
|
||||
# logger.info(f"Add constraint for namespace: {self.namespace}")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error accessing or setting up the database: {str(e)}")
|
||||
# raise
|
||||
|
||||
async def _init_workspace(self):
    """Verify that the Neo4j server is reachable and credentials are valid."""
    await self.async_driver.verify_authentication()
    await self.async_driver.verify_connectivity()
    # TODOLater: create database if not exists always cause an error when async
    # await self.create_database()
|
||||
|
||||
async def index_start_callback(self):
    """Prepare the workspace before indexing starts.

    Verifies connectivity, then creates single-property indexes on the
    fields that queries in this class filter on (``id``, ``entity_type``,
    ``communityIds``, ``source_id``).

    Raises:
        Exception: re-raised after logging if index creation fails.
    """
    logger.info("Init Neo4j workspace")
    await self._init_workspace()

    # create index for faster searching
    try:
        async with self.async_driver.session() as session:
            await session.run(
                f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
            )

            await session.run(
                f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
            )

            await session.run(
                f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
            )

            await session.run(
                f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
            )
            logger.info("Neo4j indexes created successfully")
    except Exception as e:
        logger.error(f"Failed to create indexes: {e}")
        raise e
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
    """Return True if a node with ``id == node_id`` exists in this namespace."""
    async with self.async_driver.session() as session:
        result = await session.run(
            f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
            node_id=node_id,
        )
        record = await result.single()
        # No record at all is treated the same as "node absent".
        return record["exists"] if record else False
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Return True if any directed relationship source -> target exists.

    NOTE(review): only the s->t direction is checked here, while degree
    queries elsewhere in this class are undirected — confirm this asymmetry
    is intended.
    """
    async with self.async_driver.session() as session:
        result = await session.run(
            f"""
            MATCH (s:`{self.namespace}`)
            WHERE s.id = $source_id
            MATCH (t:`{self.namespace}`)
            WHERE t.id = $target_id
            RETURN EXISTS((s)-[]->(t)) AS exists
            """,
            source_id=source_node_id,
            target_id=target_node_id,
        )

        record = await result.single()
        return record["exists"] if record else False
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
    """Return the degree of a single node (0 when nothing is found)."""
    # Single-node convenience wrapper around the batch variant.
    degrees = await self.node_degrees_batch([node_id])
    if not degrees:
        return 0
    return degrees[0]
|
||||
|
||||
async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
    """Return the degree of each node in ``node_ids``, in input order.

    Nodes that do not exist (or that the query returns no row for) get a
    degree of 0.
    """
    if not node_ids:
        # Fixed: previously returned {} here, contradicting the list
        # contract of the non-empty path (and the annotation, which was
        # also corrected from List[str] to List[int]).
        return []

    # Pre-fill with 0 so missing nodes still yield an entry.
    result_dict = {node_id: 0 for node_id in node_ids}
    async with self.async_driver.session() as session:
        result = await session.run(
            f"""
            UNWIND $node_ids AS node_id
            MATCH (n:`{self.namespace}`)
            WHERE n.id = node_id
            OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
            RETURN node_id, COUNT(m) AS degree
            """,
            node_ids=node_ids
        )

        async for record in result:
            result_dict[record["node_id"]] = record["degree"]

    return [result_dict[node_id] for node_id in node_ids]
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the combined degree of the two endpoint nodes."""
    # Single-pair convenience wrapper around the batch variant; the batch
    # call returns at most one entry for a one-element input.
    for degree in await self.edge_degrees_batch([(src_id, tgt_id)]):
        return degree
    return 0
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
    """Return deg(src) + deg(tgt) for each (src, tgt) pair, in input order.

    Pairs whose endpoints are missing keep a degree of 0. On any query
    error the whole batch degrades to zeros (best-effort semantics).
    """
    if not edge_pairs:
        return []

    # Pre-fill with 0 so pairs the query returns no row for still resolve.
    result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}

    edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]

    try:
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $edges AS edge

                MATCH (s:`{self.namespace}`)
                WHERE s.id = edge.src_id
                WITH edge, s
                OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
                WITH edge, COUNT(n1) AS src_degree

                MATCH (t:`{self.namespace}`)
                WHERE t.id = edge.tgt_id
                WITH edge, src_degree, t
                OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
                WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree

                RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
                """,
                edges=edges_params
            )

            async for record in result:
                src_id = record["src_id"]
                tgt_id = record["tgt_id"]
                degree = record["degree"]

                # Update the result dict with the computed degree.
                edge_pair = (src_id, tgt_id)
                result_dict[edge_pair] = degree

        return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
    except Exception as e:
        logger.error(f"Error in batch edge degree calculation: {e}")
        return [0] * len(edge_pairs)
|
||||
|
||||
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
    """Fetch one node's property map, or None if it does not exist."""
    # Single-node convenience wrapper around the batch variant.
    nodes = await self.get_nodes_batch([node_id])
    if not nodes:
        return None
    return nodes[0]
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
    """Fetch the property map of each node in ``node_ids``, in input order.

    Missing nodes yield None. Each found node additionally gets a
    ``clusters`` entry: a JSON string derived from its ``communityIds``
    list, where the list index is interpreted as the cluster level.

    Raises:
        Exception: re-raised after logging if the query fails.
    """
    # Fixed: the annotation previously claimed dict[str, ...] and the
    # empty case returned {}, but the method returns an ordered list.
    if not node_ids:
        return []

    result_dict = {node_id: None for node_id in node_ids}

    try:
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $node_ids AS node_id
                MATCH (n:`{self.namespace}`)
                WHERE n.id = node_id
                RETURN node_id, properties(n) AS node_data
                """,
                node_ids=node_ids
            )

            async for record in result:
                node_id = record["node_id"]
                raw_node_data = record["node_data"]

                if raw_node_data:
                    # Mirror the NetworkX backend, which stores clusters as
                    # a JSON string of {"level", "cluster"} dicts.
                    raw_node_data["clusters"] = json.dumps(
                        [
                            {
                                "level": index,
                                "cluster": cluster_id,
                            }
                            for index, cluster_id in enumerate(
                                raw_node_data.get("communityIds", [])
                            )
                        ]
                    )
                result_dict[node_id] = raw_node_data
        return [result_dict[node_id] for node_id in node_ids]
    except Exception as e:
        logger.error(f"Error in batch node retrieval: {e}")
        raise e
|
||||
|
||||
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
    """Fetch the property map of the edge source -> target, or None."""
    # Single-pair convenience wrapper around the batch variant.
    results = await self.get_edges_batch([(source_node_id, target_node_id)])
    return results[0] if results else None
|
||||
|
||||
async def get_edges_batch(
    self, edge_pairs: list[tuple[str, str]]
) -> list[Union[dict, None]]:
    """Fetch relationship properties for each (source, target) pair.

    Results keep input order; pairs with no matching relationship yield
    None. On any query error the whole batch degrades to Nones.
    """
    if not edge_pairs:
        return []

    result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}

    edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]

    try:
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $edges AS edge
                MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                WHERE s.id = edge.source_id AND t.id = edge.target_id
                RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
                """,
                edges=edges_params
            )

            async for record in result:
                source_id = record["source_id"]
                target_id = record["target_id"]
                edge_data = record["edge_data"]

                edge_pair = (source_id, target_id)
                result_dict[edge_pair] = edge_data

        return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
    except Exception as e:
        logger.error(f"Error in batch edge retrieval: {e}")
        return [None] * len(edge_pairs)
|
||||
|
||||
async def get_node_edges(
    self, source_node_id: str
) -> list[tuple[str, str]]:
    """Return the outgoing (source_id, target_id) edges of one node."""
    # Single-node convenience wrapper around the batch variant.
    results = await self.get_nodes_edges_batch([source_node_id])
    return results[0] if results else []
|
||||
|
||||
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> list[list[tuple[str, str]]]:
    """Return, for each node id, its list of outgoing (source, target) edges.

    Results keep input order; nodes without edges (or missing nodes) map
    to an empty list. On any query error the whole batch degrades to
    empty lists.
    """
    if not node_ids:
        return []

    result_dict = {node_id: [] for node_id in node_ids}

    try:
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $node_ids AS node_id
                MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                WHERE s.id = node_id
                RETURN s.id AS source_id, t.id AS target_id
                """,
                node_ids=node_ids
            )

            async for record in result:
                source_id = record["source_id"]
                target_id = record["target_id"]

                # Guard against rows for ids we were not asked about.
                if source_id in result_dict:
                    result_dict[source_id].append((source_id, target_id))

        return [result_dict[node_id] for node_id in node_ids]
    except Exception as e:
        logger.error(f"Error in batch node edges retrieval: {e}")
        return [[] for _ in node_ids]
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    """Create or update a single node (delegates to the batch variant)."""
    await self.upsert_nodes_batch([(node_id, node_data)])
|
||||
|
||||
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
    """Create or update nodes in bulk.

    Nodes are grouped by ``entity_type`` so each group can be MERGEd with
    both the namespace label and its type label in one Cypher statement.
    """
    if not nodes_data:
        # Fixed: was ``return []`` — every other path of this mutator
        # returns None, matching stdlib convention for in-place ops.
        return

    # Group nodes by entity type; defaultdict replaces the manual
    # "if key not in dict: dict[key] = []" dance.
    nodes_by_type = defaultdict(list)
    for node_id, node_data in nodes_data:
        # entity_type values may arrive wrapped in quotes (LLM output);
        # strip them so the Cypher label is clean.
        node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
        nodes_by_type[node_type].append((node_id, node_data))

    async with self.async_driver.session() as session:
        for node_type, type_nodes in nodes_by_type.items():
            params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]

            await session.run(
                f"""
                UNWIND $nodes AS node
                MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
                SET n += node.data
                """,
                nodes=params
            )
|
||||
|
||||
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
    """Create or update a single edge (delegates to the batch variant)."""
    await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])
|
||||
|
||||
|
||||
async def upsert_edges_batch(
    self, edges_data: list[tuple[str, str, dict[str, str]]]
):
    """Create or update RELATED edges in bulk.

    Every edge gets a ``weight`` property (default 0.0) because the GDS
    projection in :meth:`clustering` requires it on all relationships.
    Edges whose endpoints do not exist are silently skipped by the MATCH.
    """
    if not edges_data:
        return

    edges_params = []
    for source_id, target_id, edge_data in edges_data:
        # Copy before mutating so the caller's dict is left untouched.
        edge_data_copy = edge_data.copy()
        edge_data_copy.setdefault("weight", 0.0)

        edges_params.append({
            "source_id": source_id,
            "target_id": target_id,
            "edge_data": edge_data_copy
        })

    async with self.async_driver.session() as session:
        await session.run(
            f"""
            UNWIND $edges AS edge
            MATCH (s:`{self.namespace}`)
            WHERE s.id = edge.source_id
            WITH edge, s
            MATCH (t:`{self.namespace}`)
            WHERE t.id = edge.target_id
            MERGE (s)-[r:RELATED]->(t)
            SET r += edge.edge_data
            """,
            edges=edges_params
        )
|
||||
|
||||
|
||||
|
||||
|
||||
async def clustering(self, algorithm: str):
    """Run community detection and write results onto the nodes.

    Projects the namespace subgraph into GDS (undirected, weighted),
    runs Leiden with intermediate communities, and writes the per-level
    community ids to each node's ``communityIds`` property. Requires the
    Neo4j Graph Data Science plugin.

    Raises:
        ValueError: if ``algorithm`` is anything other than "leiden".
    """
    if algorithm != "leiden":
        raise ValueError(
            f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
        )

    random_seed = self.global_config["graph_cluster_seed"]
    # NOTE(review): this config key is named max_graph_cluster_SIZE but is
    # passed to the Leiden maxLevels parameter — confirm this is intended.
    max_level = self.global_config["max_graph_cluster_size"]
    async with self.async_driver.session() as session:
        try:
            # Project the graph with undirected relationships
            await session.run(
                f"""
                CALL gds.graph.project(
                    'graph_{self.namespace}',
                    ['{self.namespace}'],
                    {{
                        RELATED: {{
                            orientation: 'UNDIRECTED',
                            properties: ['weight']
                        }}
                    }}
                )
                """
            )

            # Run Leiden algorithm
            result = await session.run(
                f"""
                CALL gds.leiden.write(
                    'graph_{self.namespace}',
                    {{
                        writeProperty: 'communityIds',
                        includeIntermediateCommunities: True,
                        relationshipWeightProperty: "weight",
                        maxLevels: {max_level},
                        tolerance: 0.0001,
                        gamma: 1.0,
                        theta: 0.01,
                        randomSeed: {random_seed}
                    }}
                )
                YIELD communityCount, modularities;
                """
            )
            result = await result.single()
            community_count: int = result["communityCount"]
            modularities = result["modularities"]
            logger.info(
                f"Performed graph clustering with {community_count} communities and modularities {modularities}"
            )
        finally:
            # Drop the projected graph so repeated runs do not collide on
            # the projection name (GDS projections persist per session).
            await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
|
||||
|
||||
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
    """Build the community schema from ``communityIds`` written by clustering.

    For every node, its ``communityIds`` list index is taken as the level
    and the value as the cluster id. Each cluster accumulates its member
    nodes, undirected edges, and the chunk ids split out of ``source_id``.
    ``occurrence`` is each cluster's chunk count normalized by the largest
    cluster's chunk count.
    """
    results = defaultdict(
        lambda: dict(
            level=None,
            title=None,
            edges=set(),
            nodes=set(),
            chunk_ids=set(),
            occurrence=0.0,
            sub_communities=[],
        )
    )

    async with self.async_driver.session() as session:
        # Fetch community data
        result = await session.run(
            f"""
            MATCH (n:`{self.namespace}`)
            WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
            RETURN n.id AS node_id, n.source_id AS source_id,
                   communityIds AS cluster_key,
                   connected_nodes
            """
        )

        # records = await result.fetch()

        max_num_ids = 0
        async for record in result:
            for index, c_id in enumerate(record["cluster_key"]):
                node_id = str(record["node_id"])
                source_id = record["source_id"]
                level = index
                cluster_key = str(c_id)
                connected_nodes = record["connected_nodes"]

                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                # Sort each endpoint pair so undirected edges deduplicate.
                results[cluster_key]["edges"].update(
                    [
                        tuple(sorted([node_id, str(connected)]))
                        for connected in connected_nodes
                        if connected != node_id
                    ]
                )
                chunk_ids = source_id.split(GRAPH_FIELD_SEP)
                results[cluster_key]["chunk_ids"].update(chunk_ids)
                max_num_ids = max(
                    max_num_ids, len(results[cluster_key]["chunk_ids"])
                )

        # Process results: convert the accumulator sets to lists and
        # normalize occurrence by the largest chunk set seen.
        for k, v in results.items():
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids

        # Compute sub-communities (this is a simplified approach):
        # any deeper-level cluster whose nodes are a subset counts.
        for cluster in results.values():
            cluster["sub_communities"] = [
                sub_key
                for sub_key, sub_cluster in results.items()
                if sub_cluster["level"] > cluster["level"]
                and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
            ]

    return dict(results)
|
||||
|
||||
async def index_done_callback(self):
    """Close the driver when indexing finishes; data lives in Neo4j itself."""
    await self.async_driver.close()
|
||||
|
||||
async def _debug_delete_all_node_edges(self):
    """Debug helper: wipe every node and relationship in this namespace.

    Relationships must go first — Neo4j refuses to DELETE a node that
    still has attached relationships.

    Raises:
        Exception: re-raised after logging if either delete fails.
    """
    async with self.async_driver.session() as session:
        try:
            # Delete all relationships in the namespace
            await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")

            # Delete all nodes in the namespace
            await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")

            logger.info(
                f"All nodes and edges in namespace '{self.namespace}' have been deleted."
            )
        except Exception as e:
            logger.error(f"Error deleting nodes and edges: {str(e)}")
            raise
|
||||
268
rag-web-ui/backend/nano_graphrag/_storage/gdb_networkx.py
Normal file
268
rag-web-ui/backend/nano_graphrag/_storage/gdb_networkx.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast, List
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
SingleCommunitySchema,
|
||||
)
|
||||
from ..prompt import GRAPH_FIELD_SEP
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@staticmethod
def load_nx_graph(file_name) -> Union[nx.Graph, None]:
    """Load a GraphML file, or return None if the file does not exist.

    Fixed annotation: the original claimed ``nx.Graph`` but explicitly
    returns None for a missing file.
    """
    if os.path.exists(file_name):
        return nx.read_graphml(file_name)
    return None
|
||||
|
||||
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
    """Serialize ``graph`` to ``file_name`` in GraphML format."""
    logger.info(
        f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
    )
    nx.write_graphml(graph, file_name)
|
||||
|
||||
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
    Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
    """
    # Imported lazily: graspologic is an optional, heavyweight dependency
    # only needed for clustering-related code paths.
    from graspologic.utils import largest_connected_component

    graph = graph.copy()
    graph = cast(nx.Graph, largest_connected_component(graph))
    # Normalize node labels (uppercase, trimmed, HTML entities decoded) so
    # equivalent labels merge deterministically.
    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
    graph = nx.relabel_nodes(graph, node_mapping)
    return NetworkXStorage._stabilize_graph(graph)
|
||||
|
||||
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
    Ensure an undirected graph with the same relationships will always be read the same way.
    """
    stable = nx.DiGraph() if graph.is_directed() else nx.Graph()

    # Insert nodes in deterministic (sorted-by-name) order.
    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))

    edge_list = list(graph.edges(data=True))
    if not graph.is_directed():
        # For undirected graphs, normalize each edge so the smaller
        # endpoint always comes first.
        edge_list = [
            (u, v, data) if u <= v else (v, u, data)
            for u, v, data in edge_list
        ]

    # Deterministic edge order via a stable string key.
    edge_list.sort(key=lambda e: f"{e[0]} -> {e[1]}")
    stable.add_edges_from(edge_list)
    return stable
|
||||
|
||||
def __post_init__(self):
    """Load any existing GraphML file for this namespace, else start empty."""
    self._graphml_xml_file = os.path.join(
        self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
    )
    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
    if preloaded_graph is not None:
        logger.info(
            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
        )
    self._graph = preloaded_graph or nx.Graph()
    # Dispatch tables for pluggable clustering / embedding algorithms.
    self._clustering_algorithms = {
        "leiden": self._leiden_clustering,
    }
    self._node_embed_algorithms = {
        "node2vec": self._node2vec_embed,
    }
|
||||
|
||||
async def index_done_callback(self):
    """Persist the in-memory graph to its GraphML file."""
    NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
    """Return True if the node exists in the in-memory graph."""
    return self._graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Return True if an edge exists between the two nodes."""
    return self._graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
    """Return the node's attribute dict, or None if absent."""
    return self._graph.nodes.get(node_id)
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
    """Return each node's attribute dict (None for missing), in input order.

    Fixed annotation: ``asyncio.gather`` returns a list, not a dict —
    matching how callers index the result positionally.
    """
    return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
    """Return the node's degree, or 0 when the node is absent."""
    # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
    return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
|
||||
|
||||
async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
    """Return each node's degree (0 for missing nodes), in input order.

    Fixed annotation: degrees are ints, not strs (was ``List[str]``).
    """
    return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return deg(src) + deg(tgt); missing endpoints contribute 0."""
    return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
        self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
    )
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
    """Return the combined endpoint degree of each pair, in input order."""
    return await asyncio.gather(*[self.edge_degree(src_id, tgt_id) for src_id, tgt_id in edge_pairs])
|
||||
|
||||
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
    """Return the edge's attribute dict, or None if the edge is absent."""
    return self._graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_edges_batch(
    self, edge_pairs: list[tuple[str, str]]
) -> list[Union[dict, None]]:
    """Return each pair's edge attribute dict (None if absent), in input order."""
    return await asyncio.gather(*[self.get_edge(source_node_id, target_node_id) for source_node_id, target_node_id in edge_pairs])
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
    """Return the node's edge list, or None if the node is absent.

    NOTE(review): the Neo4j counterpart returns [] for a missing node —
    callers of the batch variant should tolerate both.
    """
    if self._graph.has_node(source_node_id):
        return list(self._graph.edges(source_node_id))
    return None
|
||||
|
||||
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> list[list[tuple[str, str]]]:
    """Return each node's edge list concurrently, in input order."""
    tasks = [self.get_node_edges(node_id) for node_id in node_ids]
    return await asyncio.gather(*tasks)
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    """Add the node, or merge ``node_data`` into its existing attributes."""
    self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
    """Upsert many nodes (concurrent wrappers over the single upsert)."""
    await asyncio.gather(*[self.upsert_node(node_id, node_data) for node_id, node_data in nodes_data])
|
||||
|
||||
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
    """Add the edge, or merge ``edge_data`` into its existing attributes."""
    self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def upsert_edges_batch(
    self, edges_data: list[tuple[str, str, dict[str, str]]]
):
    """Upsert many edges (concurrent wrappers over the single upsert)."""
    await asyncio.gather(*[self.upsert_edge(source_node_id, target_node_id, edge_data)
                           for source_node_id, target_node_id, edge_data in edges_data])
|
||||
|
||||
async def clustering(self, algorithm: str):
    """Run the named clustering algorithm from the dispatch table.

    Raises:
        ValueError: if ``algorithm`` is not registered (only "leiden" is).
    """
    if algorithm not in self._clustering_algorithms:
        raise ValueError(f"Clustering algorithm {algorithm} not supported")
    await self._clustering_algorithms[algorithm]()
|
||||
|
||||
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
    """Build the community schema from the per-node ``clusters`` JSON.

    Each node's ``clusters`` attribute (written by clustering) is a JSON
    list of {"level", "cluster"} dicts. Clusters accumulate their member
    nodes, incident edges, and chunk ids split from ``source_id``.
    ``occurrence`` is normalized by the largest cluster's chunk count, and
    sub-communities are next-level clusters whose nodes are a subset.
    """
    results = defaultdict(
        lambda: dict(
            level=None,
            title=None,
            edges=set(),
            nodes=set(),
            chunk_ids=set(),
            occurrence=0.0,
            sub_communities=[],
        )
    )
    max_num_ids = 0
    levels = defaultdict(set)
    for node_id, node_data in self._graph.nodes(data=True):
        # Nodes never assigned to a community carry no "clusters" key.
        if "clusters" not in node_data:
            continue
        clusters = json.loads(node_data["clusters"])
        this_node_edges = self._graph.edges(node_id)

        for cluster in clusters:
            level = cluster["level"]
            cluster_key = str(cluster["cluster"])
            levels[level].add(cluster_key)
            results[cluster_key]["level"] = level
            results[cluster_key]["title"] = f"Cluster {cluster_key}"
            results[cluster_key]["nodes"].add(node_id)
            # Sort endpoints so undirected edges deduplicate in the set.
            results[cluster_key]["edges"].update(
                [tuple(sorted(e)) for e in this_node_edges]
            )
            results[cluster_key]["chunk_ids"].update(
                node_data["source_id"].split(GRAPH_FIELD_SEP)
            )
            max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))

    # Link each level's communities to next-level communities whose node
    # sets they contain.
    ordered_levels = sorted(levels.keys())
    for i, curr_level in enumerate(ordered_levels[:-1]):
        next_level = ordered_levels[i + 1]
        this_level_comms = levels[curr_level]
        next_level_comms = levels[next_level]
        # compute the sub-communities by nodes intersection
        for comm in this_level_comms:
            results[comm]["sub_communities"] = [
                c
                for c in next_level_comms
                if results[c]["nodes"].issubset(results[comm]["nodes"])
            ]

    # Convert accumulator sets to lists and normalize occurrence.
    for k, v in results.items():
        v["edges"] = list(v["edges"])
        v["edges"] = [list(e) for e in v["edges"]]
        v["nodes"] = list(v["nodes"])
        v["chunk_ids"] = list(v["chunk_ids"])
        v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
    return dict(results)
|
||||
|
||||
def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
    """Write each node's cluster memberships onto it as a JSON string."""
    for node_id, clusters in cluster_data.items():
        self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
|
||||
|
||||
async def _leiden_clustering(self):
    """Cluster the graph's largest connected component with hierarchical Leiden.

    Results are written back onto the nodes via
    :meth:`_cluster_data_to_subgraphs`; nodes outside the largest
    connected component receive no cluster data.
    """
    # Imported lazily: graspologic is an optional, heavyweight dependency.
    from graspologic.partition import hierarchical_leiden

    graph = NetworkXStorage.stable_largest_connected_component(self._graph)
    community_mapping = hierarchical_leiden(
        graph,
        max_cluster_size=self.global_config["max_graph_cluster_size"],
        random_seed=self.global_config["graph_cluster_seed"],
    )

    # node -> list of {"level", "cluster"} memberships.
    node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
    # level -> set of cluster ids, kept only for the log line below.
    __levels = defaultdict(set)
    for partition in community_mapping:
        level_key = partition.level
        cluster_id = partition.cluster
        node_communities[partition.node].append(
            {"level": level_key, "cluster": cluster_id}
        )
        __levels[level_key].add(cluster_id)
    node_communities = dict(node_communities)
    __levels = {k: len(v) for k, v in __levels.items()}
    logger.info(f"Each level has communities: {dict(__levels)}")
    self._cluster_data_to_subgraphs(node_communities)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
    """Run the named node-embedding algorithm from the dispatch table.

    Returns:
        (embeddings, node_ids) as produced by the algorithm.

    Raises:
        ValueError: if ``algorithm`` is not registered (only "node2vec" is).
    """
    if algorithm not in self._node_embed_algorithms:
        raise ValueError(f"Node embedding algorithm {algorithm} not supported")
    return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
async def _node2vec_embed(self):
    """Embed all nodes with node2vec; returns (embeddings, node ids).

    NOTE(review): this reads each node's "id" attribute rather than using
    the graph key itself — nodes without an "id" attribute would raise
    KeyError; confirm upsert always sets it.
    """
    # Imported lazily: graspologic is an optional, heavyweight dependency.
    from graspologic import embed

    embeddings, nodes = embed.node2vec_embed(
        self._graph,
        **self.global_config["node2vec_params"],
    )

    nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
    return embeddings, nodes_ids
|
||||
46
rag-web-ui/backend/nano_graphrag/_storage/kv_json.py
Normal file
46
rag-web-ui/backend/nano_graphrag/_storage/kv_json.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .._utils import load_json, logger, write_json
|
||||
from ..base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value storage backed by a single JSON file per namespace."""

    def __post_init__(self):
        # Resolve the backing file path and load whatever is already there.
        self._file_name = os.path.join(
            self.global_config["working_dir"], f"kv_store_{self.namespace}.json"
        )
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        """Return every stored key."""
        return list(self._data)

    async def index_done_callback(self):
        """Flush the in-memory data back to the JSON file."""
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        """Return the value stored under *id*, or None if absent."""
        return self._data.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Return values for *ids* in order, None for missing entries.

        When *fields* is given, each found record is reduced to just those
        keys.
        """
        if fields is None:
            return [self._data.get(id) for id in ids]
        selected = []
        for id in ids:
            record = self._data.get(id, None)
            if record:
                selected.append({k: v for k, v in record.items() if k in fields})
            else:
                selected.append(None)
        return selected

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of *data* not yet present in the store."""
        return {key for key in data if key not in self._data}

    async def upsert(self, data: dict[str, dict]):
        """Insert or overwrite entries from *data*."""
        self._data.update(data)

    async def drop(self):
        """Discard all in-memory data (the file is untouched until flush)."""
        self._data = {}
|
||||
141
rag-web-ui/backend/nano_graphrag/_storage/vdb_hnswlib.py
Normal file
141
rag-web-ui/backend/nano_graphrag/_storage/vdb_hnswlib.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import pickle
|
||||
import hnswlib
|
||||
import numpy as np
|
||||
import xxhash
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
|
||||
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an on-disk hnswlib cosine index.

    The index and its id -> metadata mapping are persisted to two files
    under the configured working directory and reloaded on construction.
    Record ids are hashed to 32-bit integers (xxhash) to serve as hnswlib
    labels.
    """

    # hnswlib build/search parameters; all can be overridden via
    # global_config["vector_db_storage_cls_kwargs"] in __post_init__.
    ef_construction: int = 100
    M: int = 16
    max_elements: int = 1000000
    ef_search: int = 50
    num_threads: int = -1  # -1: let hnswlib choose its own thread count
    _index: Any = field(init=False)  # hnswlib.Index, created in __post_init__
    _metadata: dict[str, dict] = field(default_factory=dict)  # hashed id -> stored fields
    _current_elements: int = 0  # vectors currently held by the index

    def __post_init__(self):
        """Create or reload the hnswlib index and its metadata sidecar file."""
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        # Upper bound on how many texts are embedded per embedding_func call.
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)

        # Allow per-deployment overrides of the dataclass defaults.
        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )

        # Reload a previous session only when BOTH files survive; a lone
        # index or metadata file falls through to a fresh, empty index.
        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            # NOTE(review): pickle.load on this file is only safe because it
            # is written by index_done_callback below — do not point it at
            # untrusted data.
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        """Embed and add records to the index.

        Args:
            data: mapping of record id -> payload; each payload must contain
                "content" (the text to embed) and may carry extra fields, of
                which only those in ``self.meta_fields`` are kept as metadata.

        Returns:
            The uint32 hnswlib labels assigned to the inserted records.

        Raises:
            ValueError: when the insert would exceed ``max_elements``.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            # NOTE(review): returns a plain list here although the annotation
            # promises np.ndarray — callers must tolerate both.
            return []

        # Capacity check up front: hnswlib cannot grow past max_elements.
        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )

        # Keep only the whitelisted metadata fields alongside each id.
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed in bounded batches, all batches issued concurrently.
        batch_size = min(self._embedding_batch_num, len(contents))
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )

        # hnswlib labels must be integers: hash each string id to uint32.
        # NOTE(review): 32-bit hashes can collide; colliding ids would
        # silently overwrite each other's metadata and vector.
        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Return up to *top_k* nearest records for *query*.

        Each result is the stored metadata dict augmented with "distance"
        (cosine distance) and "similarity" (1 - distance).
        """
        if self._current_elements == 0:
            return []

        # hnswlib raises if k exceeds the element count, so clamp first.
        top_k = min(top_k, self._current_elements)

        # ef must be >= k for knn_query to succeed.
        # NOTE(review): ef is raised here but never lowered back to
        # ef_search, so a single large top_k permanently widens the search.
        if top_k > self.ef_search:
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)

        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )

        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        """Persist the index and its metadata so __post_init__ can reload them."""
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)
|
||||
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
|
||||
from .._utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
|
||||
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by NanoVectorDB, persisted as a JSON file.

    Query hits whose cosine similarity falls below
    ``cosine_better_than_threshold`` are filtered out by the client.
    """

    # Minimum cosine similarity for a query hit; may be overridden via
    # global_config["query_better_than_threshold"] in __post_init__.
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        """Open (or create) the JSON-backed client under working_dir."""
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        # Batch size for fanning out embedding calls in upsert().
        # NOTE(review): a missing key raises KeyError here, unlike the HNSW
        # sibling which defaults to 100 — confirm config always supplies it.
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "query_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed and insert/overwrite records.

        Args:
            data: mapping of record id -> payload; each payload must contain
                "content" (the text to embed) and may carry extra fields, of
                which only those in ``self.meta_fields`` are stored.

        Returns:
            The client's upsert result, or [] when *data* is empty.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:  # idiomatic emptiness check (was: `if not len(data)`)
            logger.warning("You insert an empty data to vector DB")
            return []
        # Keep only the whitelisted metadata fields alongside each id.
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed in bounded batches, all batches issued concurrently.
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        return self._client.upsert(datas=list_data)

    async def query(self, query: str, top_k=5):
        """Return up to *top_k* records above the similarity threshold.

        Each result is the stored record augmented with "id" and
        "distance" (the client's "__metrics__" score).
        """
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        return [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]

    async def index_done_callback(self):
        """Persist the client's state to its JSON storage file."""
        self._client.save()
|
||||
Reference in New Issue
Block a user