init. project

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
from .gdb_networkx import NetworkXStorage
from .gdb_neo4j import Neo4jStorage
from .vdb_nanovectordb import NanoVectorDBStorage
from .kv_json import JsonKVStorage
try:
from .vdb_hnswlib import HNSWVectorStorage
except ImportError:
HNSWVectorStorage = None

View File

@@ -0,0 +1,529 @@
import json
import asyncio
from collections import defaultdict
from typing import List
from neo4j import AsyncGraphDatabase
from dataclasses import dataclass
from typing import Union
from ..base import BaseGraphStorage, SingleCommunitySchema
from .._utils import logger
from ..prompt import GRAPH_FIELD_SEP
# Module-level lock guarding concurrent Neo4j operations.
neo4j_lock = asyncio.Lock()


def make_path_idable(path):
    """Sanitize a filesystem path into a string usable as a Neo4j label.

    Path separators (``/`` and ``\\``) become double underscores; ``.``,
    ``-`` and ``:`` become single underscores.
    """
    table = str.maketrans({
        ".": "_",
        "/": "__",
        "-": "_",
        ":": "_",
        "\\": "__",
    })
    return path.translate(table)
@dataclass
class Neo4jStorage(BaseGraphStorage):
    """Graph storage backed by a Neo4j server.

    Every node written by this instance carries the label
    ``{sanitized working_dir}__{namespace}`` so several projects can share
    one database without colliding.  The batch variants of the single-item
    operations issue one UNWIND query instead of N round trips.
    """

    def __post_init__(self):
        # Connection parameters are required entries of addon_params.
        self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
        self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
        # Prefix the namespace with the sanitized working_dir so the Neo4j
        # label is unique per project.
        self.namespace = (
            f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
        )
        logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
        if self.neo4j_url is None or self.neo4j_auth is None:
            raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
        self.async_driver = AsyncGraphDatabase.driver(
            self.neo4j_url, auth=self.neo4j_auth, max_connection_pool_size=50,
        )

    # NOTE: a create_database() helper that added a UNIQUENESS constraint on
    # (n.id) was removed — CREATE CONSTRAINT IF NOT EXISTS kept raising errors
    # from the async driver.  index_start_callback() creates plain indexes
    # instead.

    async def _init_workspace(self):
        """Verify the configured Neo4j server is reachable and auth works."""
        await self.async_driver.verify_authentication()
        await self.async_driver.verify_connectivity()
        # TODOLater: create database if not exists always causes an error when async

    async def index_start_callback(self):
        """Check connectivity and create per-namespace lookup indexes."""
        logger.info("Init Neo4j workspace")
        await self._init_workspace()
        # create indexes for faster searching
        try:
            async with self.async_driver.session() as session:
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.id)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.entity_type)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.communityIds)"
                )
                await session.run(
                    f"CREATE INDEX IF NOT EXISTS FOR (n:`{self.namespace}`) ON (n.source_id)"
                )
            logger.info("Neo4j indexes created successfully")
        except Exception as e:
            logger.error(f"Failed to create indexes: {e}")
            raise e

    async def has_node(self, node_id: str) -> bool:
        """Return True if a node with the given id exists in this namespace."""
        async with self.async_driver.session() as session:
            result = await session.run(
                f"MATCH (n:`{self.namespace}`) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
                node_id=node_id,
            )
            record = await result.single()
            return record["exists"] if record else False

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if a directed edge source -> target exists."""
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                MATCH (s:`{self.namespace}`)
                WHERE s.id = $source_id
                MATCH (t:`{self.namespace}`)
                WHERE t.id = $target_id
                RETURN EXISTS((s)-[]->(t)) AS exists
                """,
                source_id=source_node_id,
                target_id=target_node_id,
            )
            record = await result.single()
            return record["exists"] if record else False

    async def node_degree(self, node_id: str) -> int:
        """Degree of a single node (0 when the node is missing)."""
        results = await self.node_degrees_batch([node_id])
        return results[0] if results else 0

    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        """Degrees for many nodes in one query, in input order.

        Missing nodes get degree 0.
        """
        # Fixed: empty input now returns [] (was {}) and the annotation says
        # List[int] (was List[str]) to match the actual return value.
        if not node_ids:
            return []
        result_dict = {node_id: 0 for node_id in node_ids}
        async with self.async_driver.session() as session:
            result = await session.run(
                f"""
                UNWIND $node_ids AS node_id
                MATCH (n:`{self.namespace}`)
                WHERE n.id = node_id
                OPTIONAL MATCH (n)-[]-(m:`{self.namespace}`)
                RETURN node_id, COUNT(m) AS degree
                """,
                node_ids=node_ids
            )
            async for record in result:
                result_dict[record["node_id"]] = record["degree"]
        return [result_dict[node_id] for node_id in node_ids]

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Sum of the degrees of the two endpoint nodes."""
        results = await self.edge_degrees_batch([(src_id, tgt_id)])
        return results[0] if results else 0

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        """Endpoint-degree sums for many (src, tgt) pairs, in input order.

        On query failure logs the error and returns zeros for every pair.
        """
        if not edge_pairs:
            return []
        result_dict = {tuple(edge_pair): 0 for edge_pair in edge_pairs}
        edges_params = [{"src_id": src, "tgt_id": tgt} for src, tgt in edge_pairs]
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge
                    MATCH (s:`{self.namespace}`)
                    WHERE s.id = edge.src_id
                    WITH edge, s
                    OPTIONAL MATCH (s)-[]-(n1:`{self.namespace}`)
                    WITH edge, COUNT(n1) AS src_degree
                    MATCH (t:`{self.namespace}`)
                    WHERE t.id = edge.tgt_id
                    WITH edge, src_degree, t
                    OPTIONAL MATCH (t)-[]-(n2:`{self.namespace}`)
                    WITH edge.src_id AS src_id, edge.tgt_id AS tgt_id, src_degree, COUNT(n2) AS tgt_degree
                    RETURN src_id, tgt_id, src_degree + tgt_degree AS degree
                    """,
                    edges=edges_params
                )
                async for record in result:
                    # update the result dict
                    edge_pair = (record["src_id"], record["tgt_id"])
                    result_dict[edge_pair] = record["degree"]
            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge degree calculation: {e}")
            return [0] * len(edge_pairs)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's properties dict, or None when missing."""
        result = await self.get_nodes_batch([node_id])
        return result[0] if result else None

    async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
        """Property dicts for many nodes, in input order; None for missing.

        Each found node additionally gets a "clusters" key: a JSON list of
        {"level", "cluster"} pairs derived from its communityIds property.
        """
        # Fixed: empty input now returns [] (was {}) and the annotation is
        # list[...] (was dict[...]) to match the actual return value.
        if not node_ids:
            return []
        result_dict = {node_id: None for node_id in node_ids}
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $node_ids AS node_id
                    MATCH (n:`{self.namespace}`)
                    WHERE n.id = node_id
                    RETURN node_id, properties(n) AS node_data
                    """,
                    node_ids=node_ids
                )
                async for record in result:
                    node_id = record["node_id"]
                    raw_node_data = record["node_data"]
                    if raw_node_data:
                        raw_node_data["clusters"] = json.dumps(
                            [
                                {
                                    "level": index,
                                    "cluster": cluster_id,
                                }
                                for index, cluster_id in enumerate(
                                    raw_node_data.get("communityIds", [])
                                )
                            ]
                        )
                        result_dict[node_id] = raw_node_data
            return [result_dict[node_id] for node_id in node_ids]
        except Exception as e:
            logger.error(f"Error in batch node retrieval: {e}")
            raise e

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the properties of an edge source -> target, or None."""
        results = await self.get_edges_batch([(source_node_id, target_node_id)])
        return results[0] if results else None

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        """Edge property dicts for many (src, tgt) pairs, in input order.

        Missing edges map to None; on failure returns all-None.
        """
        if not edge_pairs:
            return []
        result_dict = {tuple(edge_pair): None for edge_pair in edge_pairs}
        edges_params = [{"source_id": src, "target_id": tgt} for src, tgt in edge_pairs]
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $edges AS edge
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = edge.source_id AND t.id = edge.target_id
                    RETURN edge.source_id AS source_id, edge.target_id AS target_id, properties(r) AS edge_data
                    """,
                    edges=edges_params
                )
                async for record in result:
                    edge_pair = (record["source_id"], record["target_id"])
                    result_dict[edge_pair] = record["edge_data"]
            return [result_dict[tuple(edge_pair)] for edge_pair in edge_pairs]
        except Exception as e:
            logger.error(f"Error in batch edge retrieval: {e}")
            return [None] * len(edge_pairs)

    async def get_node_edges(
        self, source_node_id: str
    ) -> list[tuple[str, str]]:
        """All outgoing (source_id, target_id) pairs of one node."""
        results = await self.get_nodes_edges_batch([source_node_id])
        return results[0] if results else []

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        """Outgoing edge lists for many nodes, in input order."""
        if not node_ids:
            return []
        result_dict = {node_id: [] for node_id in node_ids}
        try:
            async with self.async_driver.session() as session:
                result = await session.run(
                    f"""
                    UNWIND $node_ids AS node_id
                    MATCH (s:`{self.namespace}`)-[r]->(t:`{self.namespace}`)
                    WHERE s.id = node_id
                    RETURN s.id AS source_id, t.id AS target_id
                    """,
                    node_ids=node_ids
                )
                async for record in result:
                    source_id = record["source_id"]
                    target_id = record["target_id"]
                    if source_id in result_dict:
                        result_dict[source_id].append((source_id, target_id))
            return [result_dict[node_id] for node_id in node_ids]
        except Exception as e:
            logger.error(f"Error in batch node edges retrieval: {e}")
            return [[] for _ in node_ids]

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Create or update a single node."""
        await self.upsert_nodes_batch([(node_id, node_data)])

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        """MERGE many nodes, grouped by entity_type so each group shares a label."""
        if not nodes_data:
            return []
        nodes_by_type = defaultdict(list)
        for node_id, node_data in nodes_data:
            # entity_type values may arrive quoted; strip the quotes for the label.
            node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
            nodes_by_type[node_type].append((node_id, node_data))
        async with self.async_driver.session() as session:
            for node_type, type_nodes in nodes_by_type.items():
                params = [{"id": node_id, "data": node_data} for node_id, node_data in type_nodes]
                await session.run(
                    f"""
                    UNWIND $nodes AS node
                    MERGE (n:`{self.namespace}`:`{node_type}` {{id: node.id}})
                    SET n += node.data
                    """,
                    nodes=params
                )

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Create or update a single RELATED edge."""
        await self.upsert_edges_batch([(source_node_id, target_node_id, edge_data)])

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        """MERGE many RELATED edges; both endpoints must already exist."""
        if not edges_data:
            return
        edges_params = []
        for source_id, target_id, edge_data in edges_data:
            # Copy so the caller's dict is not mutated; weight is required by
            # the GDS projection in clustering().
            edge_data_copy = edge_data.copy()
            edge_data_copy.setdefault("weight", 0.0)
            edges_params.append({
                "source_id": source_id,
                "target_id": target_id,
                "edge_data": edge_data_copy
            })
        async with self.async_driver.session() as session:
            await session.run(
                f"""
                UNWIND $edges AS edge
                MATCH (s:`{self.namespace}`)
                WHERE s.id = edge.source_id
                WITH edge, s
                MATCH (t:`{self.namespace}`)
                WHERE t.id = edge.target_id
                MERGE (s)-[r:RELATED]->(t)
                SET r += edge.edge_data
                """,
                edges=edges_params
            )

    async def clustering(self, algorithm: str):
        """Run Leiden community detection via the Neo4j GDS plugin.

        Writes hierarchical community ids to each node's ``communityIds``
        property.  Raises ValueError for any algorithm other than "leiden".
        """
        if algorithm != "leiden":
            raise ValueError(
                f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
            )
        random_seed = self.global_config["graph_cluster_seed"]
        max_level = self.global_config["max_graph_cluster_size"]
        async with self.async_driver.session() as session:
            try:
                # Project the graph with undirected relationships
                await session.run(
                    f"""
                    CALL gds.graph.project(
                        'graph_{self.namespace}',
                        ['{self.namespace}'],
                        {{
                            RELATED: {{
                                orientation: 'UNDIRECTED',
                                properties: ['weight']
                            }}
                        }}
                    )
                    """
                )
                # Run Leiden algorithm
                result = await session.run(
                    f"""
                    CALL gds.leiden.write(
                        'graph_{self.namespace}',
                        {{
                            writeProperty: 'communityIds',
                            includeIntermediateCommunities: True,
                            relationshipWeightProperty: "weight",
                            maxLevels: {max_level},
                            tolerance: 0.0001,
                            gamma: 1.0,
                            theta: 0.01,
                            randomSeed: {random_seed}
                        }}
                    )
                    YIELD communityCount, modularities;
                    """
                )
                result = await result.single()
                community_count: int = result["communityCount"]
                modularities = result["modularities"]
                logger.info(
                    f"Performed graph clustering with {community_count} communities and modularities {modularities}"
                )
            finally:
                # Drop the projected graph even if clustering failed.
                await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Build the per-community schema from the communityIds written by clustering().

        NOTE(review): assumes clustering() ran first — a node without a
        communityIds property yields a None cluster_key and would fail the
        iteration below; confirm against callers.
        """
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        async with self.async_driver.session() as session:
            # Fetch each node's community ids together with its neighbours.
            result = await session.run(
                f"""
                MATCH (n:`{self.namespace}`)
                WITH n, n.communityIds AS communityIds, [(n)-[]-(m:`{self.namespace}`) | m.id] AS connected_nodes
                RETURN n.id AS node_id, n.source_id AS source_id,
                       communityIds AS cluster_key,
                       connected_nodes
                """
            )
            max_num_ids = 0
            async for record in result:
                # communityIds is hierarchical: index in the list == level.
                for index, c_id in enumerate(record["cluster_key"]):
                    node_id = str(record["node_id"])
                    source_id = record["source_id"]
                    level = index
                    cluster_key = str(c_id)
                    connected_nodes = record["connected_nodes"]
                    results[cluster_key]["level"] = level
                    results[cluster_key]["title"] = f"Cluster {cluster_key}"
                    results[cluster_key]["nodes"].add(node_id)
                    results[cluster_key]["edges"].update(
                        [
                            tuple(sorted([node_id, str(connected)]))
                            for connected in connected_nodes
                            if connected != node_id
                        ]
                    )
                    chunk_ids = source_id.split(GRAPH_FIELD_SEP)
                    results[cluster_key]["chunk_ids"].update(chunk_ids)
                    max_num_ids = max(
                        max_num_ids, len(results[cluster_key]["chunk_ids"])
                    )
            # Convert sets to lists and normalise occurrence to [0, 1].
            for k, v in results.items():
                v["edges"] = [list(e) for e in v["edges"]]
                v["nodes"] = list(v["nodes"])
                v["chunk_ids"] = list(v["chunk_ids"])
                v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        # Compute sub-communities by node containment (simplified approach).
        for cluster in results.values():
            cluster["sub_communities"] = [
                sub_key
                for sub_key, sub_cluster in results.items()
                if sub_cluster["level"] > cluster["level"]
                and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
            ]
        return dict(results)

    async def index_done_callback(self):
        """Close the driver and release its connection pool."""
        await self.async_driver.close()

    async def _debug_delete_all_node_edges(self):
        """Debug helper: wipe every node and edge in this namespace."""
        async with self.async_driver.session() as session:
            try:
                # Delete all relationships in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`)-[r]-() DELETE r")
                # Delete all nodes in the namespace
                await session.run(f"MATCH (n:`{self.namespace}`) DELETE n")
                logger.info(
                    f"All nodes and edges in namespace '{self.namespace}' have been deleted."
                )
            except Exception as e:
                logger.error(f"Error deleting nodes and edges: {str(e)}")
                raise

View File

@@ -0,0 +1,268 @@
import html
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Union, cast, List
import networkx as nx
import numpy as np
import asyncio
from .._utils import logger
from ..base import (
BaseGraphStorage,
SingleCommunitySchema,
)
from ..prompt import GRAPH_FIELD_SEP
@dataclass
class NetworkXStorage(BaseGraphStorage):
    """Graph storage backed by an in-memory networkx graph.

    The graph is persisted as a GraphML file under working_dir and reloaded
    on construction.  Batch methods simply fan out to the single-item
    coroutines via asyncio.gather.
    """

    @staticmethod
    def load_nx_graph(file_name) -> Union[nx.Graph, None]:
        """Load a GraphML file, or return None when it does not exist."""
        # Fixed: annotation was `-> nx.Graph` although the missing-file path
        # returns None.
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Persist the graph to a GraphML file."""
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        # Normalise node labels so equal entities compare equal.
        node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
        sorted_nodes = sorted(graph.nodes(data=True), key=lambda x: x[0])
        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))
        if not graph.is_directed():
            # Canonicalise endpoint order for undirected edges.
            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    source, target = target, source
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.Graph()
        self._clustering_algorithms = {
            "leiden": self._leiden_clustering,
        }
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        """Flush the in-memory graph to its GraphML file."""
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's attribute dict, or None when missing."""
        return self._graph.nodes.get(node_id)

    async def get_nodes_batch(self, node_ids: list[str]) -> list[Union[dict, None]]:
        """Attribute dicts for many nodes, in input order; None for missing."""
        # Fixed: annotation was dict[str, ...] although gather returns a list.
        return await asyncio.gather(*[self.get_node(node_id) for node_id in node_ids])

    async def node_degree(self, node_id: str) -> int:
        # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
        return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0

    async def node_degrees_batch(self, node_ids: List[str]) -> List[int]:
        """Degrees for many nodes, in input order (0 when missing)."""
        # Fixed: annotation was List[str] although degrees are ints.
        return await asyncio.gather(*[self.node_degree(node_id) for node_id in node_ids])

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Sum of the endpoint degrees (missing nodes count as 0)."""
        return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
            self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
        )

    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> list[int]:
        return await asyncio.gather(*[self.edge_degree(src_id, tgt_id) for src_id, tgt_id in edge_pairs])

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the edge's attribute dict, or None when missing."""
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_edges_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> list[Union[dict, None]]:
        return await asyncio.gather(*[self.get_edge(source_node_id, target_node_id) for source_node_id, target_node_id in edge_pairs])

    async def get_node_edges(self, source_node_id: str):
        """Edges incident to a node, or None when the node is missing."""
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> list[list[tuple[str, str]]]:
        # NOTE: entries are None (not []) for missing nodes, mirroring
        # get_node_edges.
        return await asyncio.gather(*[self.get_node_edges(node_id) for node_id
                                      in node_ids])

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def upsert_nodes_batch(self, nodes_data: list[tuple[str, dict[str, str]]]):
        await asyncio.gather(*[self.upsert_node(node_id, node_data) for node_id, node_data in nodes_data])

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def upsert_edges_batch(
        self, edges_data: list[tuple[str, str, dict[str, str]]]
    ):
        await asyncio.gather(*[self.upsert_edge(source_node_id, target_node_id, edge_data)
                               for source_node_id, target_node_id, edge_data in edges_data])

    async def clustering(self, algorithm: str):
        """Dispatch to a registered clustering algorithm ("leiden" only)."""
        if algorithm not in self._clustering_algorithms:
            raise ValueError(f"Clustering algorithm {algorithm} not supported")
        await self._clustering_algorithms[algorithm]()

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Build the per-community schema from the "clusters" node attribute
        written by _leiden_clustering()."""
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        max_num_ids = 0
        levels = defaultdict(set)
        for node_id, node_data in self._graph.nodes(data=True):
            if "clusters" not in node_data:
                continue
            clusters = json.loads(node_data["clusters"])
            this_node_edges = self._graph.edges(node_id)
            for cluster in clusters:
                level = cluster["level"]
                cluster_key = str(cluster["cluster"])
                levels[level].add(cluster_key)
                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                results[cluster_key]["edges"].update(
                    [tuple(sorted(e)) for e in this_node_edges]
                )
                results[cluster_key]["chunk_ids"].update(
                    node_data["source_id"].split(GRAPH_FIELD_SEP)
                )
                max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
        # Sub-communities: a next-level community whose nodes are a subset of
        # this community's nodes is counted as its child.
        ordered_levels = sorted(levels.keys())
        for i, curr_level in enumerate(ordered_levels[:-1]):
            next_level = ordered_levels[i + 1]
            this_level_comms = levels[curr_level]
            next_level_comms = levels[next_level]
            for comm in this_level_comms:
                results[comm]["sub_communities"] = [
                    c
                    for c in next_level_comms
                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                ]
        # Convert sets to lists and normalise occurrence to [0, 1].
        for k, v in results.items():
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        return dict(results)

    def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
        """Attach each node's cluster memberships as a JSON "clusters" attribute."""
        for node_id, clusters in cluster_data.items():
            self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)

    async def _leiden_clustering(self):
        """Run hierarchical Leiden on the largest connected component and
        store the memberships on the nodes."""
        from graspologic.partition import hierarchical_leiden

        graph = NetworkXStorage.stable_largest_connected_component(self._graph)
        community_mapping = hierarchical_leiden(
            graph,
            max_cluster_size=self.global_config["max_graph_cluster_size"],
            random_seed=self.global_config["graph_cluster_seed"],
        )
        node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
        __levels = defaultdict(set)
        for partition in community_mapping:
            level_key = partition.level
            cluster_id = partition.cluster
            node_communities[partition.node].append(
                {"level": level_key, "cluster": cluster_id}
            )
            __levels[level_key].add(cluster_id)
        node_communities = dict(node_communities)
        __levels = {k: len(v) for k, v in __levels.items()}
        logger.info(f"Each level has communities: {dict(__levels)}")
        self._cluster_data_to_subgraphs(node_communities)

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Dispatch to a registered node-embedding algorithm ("node2vec" only)."""
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    async def _node2vec_embed(self):
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )
        # NOTE(review): assumes every node carries an "id" attribute —
        # confirm against the upsert path.
        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

View File

@@ -0,0 +1,46 @@
import os
from dataclasses import dataclass
from .._utils import load_json, logger, write_json
from ..base import (
BaseKVStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
    """Key-value store persisted as one JSON file under working_dir."""

    def __post_init__(self):
        storage_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(storage_dir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        # Every key currently held in memory.
        return [*self._data]

    async def index_done_callback(self):
        # Persist the in-memory dict back to its JSON file.
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        # Without a field filter, return the raw records (None when absent).
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        # Otherwise project each record down to the requested fields.
        projected = []
        for id in ids:
            record = self._data.get(id, None)
            if record:
                projected.append(
                    {name: value for name, value in record.items() if name in fields}
                )
            else:
                projected.append(None)
        return projected

    async def filter_keys(self, data: list[str]) -> set[str]:
        # Keys from `data` that are not yet stored.
        return {key for key in data if key not in self._data}

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)

    async def drop(self):
        self._data = {}

View File

@@ -0,0 +1,141 @@
import asyncio
import os
from dataclasses import dataclass, field
from typing import Any
import pickle
import hnswlib
import numpy as np
import xxhash
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    """Vector storage backed by an hnswlib cosine index persisted to disk.

    Vector ids are 32-bit xxhash digests of the string ids; the string id
    plus the selected meta fields are kept in ``_metadata`` keyed by that
    integer id.
    """

    ef_construction: int = 100  # build-time accuracy/speed trade-off
    M: int = 16  # max links per element in the HNSW graph
    max_elements: int = 1000000  # hard capacity of the index
    ef_search: int = 50  # query-time accuracy/speed trade-off
    num_threads: int = -1  # -1 lets hnswlib choose
    _index: Any = field(init=False)
    _metadata: dict[str, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)
        # Per-storage overrides take precedence over the dataclass defaults.
        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )
        # Only reuse an on-disk index when both the index file and its
        # metadata pickle are present; otherwise start fresh.
        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        """Embed and insert the given {id: record} mapping; return the int ids.

        Raises ValueError when the insert would exceed max_elements.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            # Fixed: was `return []`, which contradicted the np.ndarray
            # annotation; an empty uint32 array is still falsy for callers.
            return np.array([], dtype=np.uint32)
        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed in batches and concatenate into a single matrix.
        batch_size = min(self._embedding_batch_num, len(contents))
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )
        # Map each string id to a deterministic 32-bit integer label.
        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        """Return the top_k nearest records with their cosine distance/similarity."""
        if self._current_elements == 0:
            return []
        top_k = min(top_k, self._current_elements)
        if top_k > self.ef_search:
            # NOTE: ef stays raised for subsequent queries; hnswlib requires
            # ef >= k for a correct search.
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)
        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )
        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        """Persist the index and its metadata side file."""
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)

View File

@@ -0,0 +1,68 @@
import asyncio
import os
from dataclasses import dataclass
import numpy as np
from nano_vectordb import NanoVectorDB
from .._utils import logger
from ..base import BaseVectorStorage
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by NanoVectorDB, persisted as a JSON file."""

    # Results with cosine similarity below this threshold are filtered out.
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        # Fixed: use .get with a default like the sibling HNSW storage does,
        # instead of raising KeyError when the optional setting is absent.
        self._max_batch_size = self.global_config.get("embedding_batch_num", 100)
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "query_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed and upsert the given {id: record} mapping into the client."""
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []
        # Keep only the configured meta fields alongside the NanoVectorDB id.
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed in batches to respect the embedding function's batch limit.
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        """Return up to top_k records above the similarity threshold."""
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        # Expose the internal id/metric fields under friendlier names.
        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]
        return results

    async def index_done_callback(self):
        """Flush the client's state to its JSON storage file."""
        self._client.save()