# Files
# rag_agent/rag-web-ui/backend/nano_graphrag/graphrag.py
# 2026-04-13 11:34:23 +08:00
#
# 383 lines
# 13 KiB
# Python

import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast
from ._llm import (
amazon_bedrock_embedding,
create_amazon_bedrock_complete_function,
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete,
)
from ._op import (
chunking_by_token_size,
extract_entities,
generate_community_report,
get_chunks,
local_query,
global_query,
naive_query,
)
from ._storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
)
from ._utils import (
EmbeddingFunc,
compute_mdhash_id,
limit_async_func_call,
convert_response_to_json,
always_get_an_event_loop,
logger,
TokenizerWrapper,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
StorageNameSpace,
QueryParam,
)
@dataclass
class GraphRAG:
    """Configuration holder and entry point for the graph-based RAG pipeline.

    Every tunable is a dataclass field. ``__post_init__`` builds the tokenizer
    wrapper and the storage instances and wraps the model / embedding
    functions with concurrency limits. Documents are added with
    ``insert`` / ``ainsert`` and questions are answered with
    ``query`` / ``aquery``.
    """

    # Default directory name is timestamped at construction time.
    # NOTE(review): the timestamp contains ':' characters, which are not valid
    # in Windows paths - pass `working_dir` explicitly on that platform.
    working_dir: str = field(
        default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    # graph mode
    enable_local: bool = True  # gates "local" queries and the entities vector DB
    enable_naive_rag: bool = False  # gates "naive" queries and the chunks vector DB

    # text chunking
    tokenizer_type: str = "tiktoken"  # or 'huggingface'
    tiktoken_model_name: str = "gpt-4o"
    huggingface_model_name: str = "bert-base-uncased"  # default HF model
    chunk_func: Callable[
        [
            list[list[int]],
            List[str],
            TokenizerWrapper,
            Optional[int],
            Optional[int],
        ],
        List[Dict[str, Union[str, int]]],
    ] = chunking_by_token_size
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # graph clustering
    graph_cluster_algorithm: str = "leiden"
    max_graph_cluster_size: int = 10
    graph_cluster_seed: int = 0xDEADBEEF

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # community reports
    special_community_report_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )

    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16
    query_better_than_threshold: float = 0.2

    # LLM
    using_azure_openai: bool = False
    using_amazon_bedrock: bool = False
    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
    best_model_func: Callable = gpt_4o_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
    cheap_model_func: Callable = gpt_4o_mini_complete
    cheap_model_max_token_size: int = 32768
    cheap_model_max_async: int = 16

    # entity extraction
    entity_extraction_func: Callable = extract_entities

    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True

    # extension
    always_create_working_dir: bool = True
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: Callable = convert_response_to_json

    def __post_init__(self):
        """Finish construction: tokenizer, provider switch, storages, rate limits.

        Steps, in order:
          1. Log the effective configuration at debug level.
          2. Build the tokenizer wrapper for the selected tokenizer type.
          3. Optionally swap the model / embedding functions for the Azure
             OpenAI or Amazon Bedrock variants.
          4. Create the working directory if requested and missing.
          5. Instantiate the KV, graph and vector storages.
          6. Wrap the embedding and model functions with async-call limits.
        """
        config_dump = ",\n ".join(f"{k} = {v}" for k, v in asdict(self).items())
        logger.debug(f"GraphRAG init with param:\n\n {config_dump}\n")

        self.tokenizer_wrapper = TokenizerWrapper(
            tokenizer_type=self.tokenizer_type,
            model_name=(
                self.tiktoken_model_name
                if self.tokenizer_type == "tiktoken"
                else self.huggingface_model_name
            ),
        )

        if self.using_azure_openai:
            # Only functions still set to the OpenAI defaults are replaced, so
            # anything the caller supplied explicitly is kept.
            if self.best_model_func == gpt_4o_complete:
                self.best_model_func = azure_gpt_4o_complete
            if self.cheap_model_func == gpt_4o_mini_complete:
                self.cheap_model_func = azure_gpt_4o_mini_complete
            if self.embedding_func == openai_embedding:
                self.embedding_func = azure_openai_embedding
            logger.info(
                "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
            )

        if self.using_amazon_bedrock:
            # Unlike the Azure branch, this replaces the functions
            # unconditionally, including any the caller supplied.
            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
            self.embedding_func = amazon_bedrock_embedding
            logger.info(
                "Switched the default openai funcs to Amazon Bedrock"
            )

        if not os.path.exists(self.working_dir) and self.always_create_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")
            # exist_ok guards against the directory appearing between the
            # existence check above and this call.
            os.makedirs(self.working_dir, exist_ok=True)

        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.community_reports = self.key_string_value_json_storage_cls(
            namespace="community_reports", global_config=asdict(self)
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        # Wrap the embedding function before the vector stores are created so
        # they receive the concurrency-limited version.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = (
            self.vector_db_storage_cls(
                namespace="entities",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
                meta_fields={"entity_name"},
            )
            if self.enable_local
            else None
        )
        self.chunks_vdb = (
            self.vector_db_storage_cls(
                namespace="chunks",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
            if self.enable_naive_rag
            else None
        )

        # Bind the response cache (None when caching is disabled) into both
        # model functions, then apply their concurrency limits.
        self.best_model_func = limit_async_func_call(self.best_model_max_async)(
            partial(self.best_model_func, hashing_kv=self.llm_response_cache)
        )
        self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
            partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        """Blocking wrapper around :meth:`ainsert`.

        Args:
            string_or_strings: A single document string or an iterable of
                document strings.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    def query(self, query: str, param: Optional[QueryParam] = None):
        """Blocking wrapper around :meth:`aquery`.

        Args:
            query: The question text.
            param: Query options; a fresh ``QueryParam()`` is used when omitted.

        Returns:
            Whatever the selected query function returns.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: Optional[QueryParam] = None):
        """Answer ``query`` using the mode selected by ``param.mode``.

        Args:
            query: The question text.
            param: Query options. Defaults to a new ``QueryParam()`` per call
                so no instance is shared between calls.

        Returns:
            The response produced by ``local_query``, ``global_query`` or
            ``naive_query``.

        Raises:
            ValueError: If the requested mode is disabled on this instance or
                is not one of "local", "global" or "naive".
        """
        if param is None:
            param = QueryParam()

        if param.mode == "local" and not self.enable_local:
            raise ValueError("enable_local is False, cannot query in local mode")
        if param.mode == "naive" and not self.enable_naive_rag:
            raise ValueError("enable_naive_rag is False, cannot query in naive mode")

        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")

        await self._query_done()
        return response

    async def ainsert(self, string_or_strings):
        """Index new documents: chunk, extract entities, cluster, report.

        Documents and chunks whose hashed ids are filtered out by the
        storages' ``filter_keys`` are skipped. The storages'
        ``index_done_callback`` hooks always run, even on early return or
        error, because of the ``finally`` clause.

        Args:
            string_or_strings: A single document string or an iterable of
                document strings.
        """
        await self._insert_start()
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            # ---------- new docs
            # Keyed by a hash of the stripped content, so identical documents
            # in the same call collapse into one entry.
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            inserting_chunks = get_chunks(
                new_docs=new_docs,
                chunk_func=self.chunk_func,
                overlap_token_size=self.chunk_overlap_token_size,
                max_token_size=self.chunk_token_size,
                tokenizer_wrapper=self.tokenizer_wrapper,
            )
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
            if self.enable_naive_rag:
                logger.info("Insert chunks for naive RAG")
                await self.chunks_vdb.upsert(inserting_chunks)

            # TODO: incremental community updates are not supported yet, so
            # all existing community reports are dropped and regenerated.
            await self.community_reports.drop()

            # ---------- extract/summary entity and upsert to graph
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await self.entity_extraction_func(
                inserting_chunks,
                # NOTE(review): "knwoledge" looks like a typo, but it is the
                # keyword this call site has always used; presumably it
                # matches the extraction function's parameter name, so it
                # must not be "corrected" here alone.
                knwoledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                tokenizer_wrapper=self.tokenizer_wrapper,
                global_config=asdict(self),
                using_amazon_bedrock=self.using_amazon_bedrock,
            )
            if maybe_new_kg is None:
                logger.warning("No new entities found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            # ---------- update clusterings of graph
            logger.info("[Community Report]...")
            await self.chunk_entity_relation_graph.clustering(
                self.graph_cluster_algorithm
            )
            await generate_community_report(
                self.community_reports,
                self.chunk_entity_relation_graph,
                self.tokenizer_wrapper,
                asdict(self),
            )

            # ---------- commit upsertings and indexing
            # Docs and chunks are stored last, after extraction and report
            # generation have completed.
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()

    async def _insert_start(self):
        """Run ``index_start_callback`` on the graph storage before an insert."""
        storages = [
            self.chunk_entity_relation_graph,
        ]
        tasks = [
            cast(StorageNameSpace, storage_inst).index_start_callback()
            for storage_inst in storages
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)

    async def _insert_done(self):
        """Run ``index_done_callback`` on every configured storage after an insert.

        Storages that are ``None`` (disabled features) are skipped.
        """
        storages = [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.community_reports,
            self.entities_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in storages
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)

    async def _query_done(self):
        """Run ``index_done_callback`` on the LLM response cache after a query.

        Does nothing when the cache is disabled (``llm_response_cache`` is None).
        """
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in [self.llm_response_cache]
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)