import asyncio
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast

from ._llm import (
    amazon_bedrock_embedding,
    create_amazon_bedrock_complete_function,
    gpt_4o_complete,
    gpt_4o_mini_complete,
    openai_embedding,
    azure_gpt_4o_complete,
    azure_openai_embedding,
    azure_gpt_4o_mini_complete,
)
from ._op import (
    chunking_by_token_size,
    extract_entities,
    generate_community_report,
    get_chunks,
    local_query,
    global_query,
    naive_query,
)
from ._storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)
from ._utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    always_get_an_event_loop,
    logger,
    TokenizerWrapper,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
)


@dataclass
class GraphRAG:
    working_dir: str = field(
        default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # graph mode
    enable_local: bool = True
    enable_naive_rag: bool = False

    # text chunking
    tokenizer_type: str = "tiktoken"  # or "huggingface"
    tiktoken_model_name: str = "gpt-4o"
    huggingface_model_name: str = "bert-base-uncased"  # default HF model
    chunk_func: Callable[
        [
            List[List[int]],
            List[str],
            TokenizerWrapper,
            Optional[int],
            Optional[int],
        ],
        List[Dict[str, Union[str, int]]],
    ] = chunking_by_token_size
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # graph clustering
    graph_cluster_algorithm: str = "leiden"
    max_graph_cluster_size: int = 10
    graph_cluster_seed: int = 0xDEADBEEF

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # community reports
    special_community_report_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )

    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16
    query_better_than_threshold: float = 0.2

    # LLM
    using_azure_openai: bool = False
    using_amazon_bedrock: bool = False
    best_model_id: str = "us.anthropic.claude-3-sonnet-20240229-v1:0"
    cheap_model_id: str = "us.anthropic.claude-3-haiku-20240307-v1:0"
    best_model_func: callable = gpt_4o_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
    cheap_model_func: callable = gpt_4o_mini_complete
    cheap_model_max_token_size: int = 32768
    cheap_model_max_async: int = 16

    # entity extraction
    entity_extraction_func: callable = extract_entities

    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True

    # extension
    always_create_working_dir: bool = True
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json
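    # A minimal sketch of plugging in a custom chunking function. The name
    # `split_by_sentences` is hypothetical; it only illustrates the call
    # signature the `chunk_func` annotation above expects (per-doc token
    # lists, the doc keys, a TokenizerWrapper, and the two token-size limits,
    # returning List[Dict[str, Union[str, int]]] chunk records):
    #
    #     def split_by_sentences(
    #         tokens_list, doc_keys, tokenizer_wrapper,
    #         overlap_token_size=100, max_token_size=1200,
    #     ):
    #         ...
    #
    #     rag = GraphRAG(chunk_func=split_by_sentences)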
    def __post_init__(self):
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"GraphRAG init with param:\n\n {_print_config}\n")

        self.tokenizer_wrapper = TokenizerWrapper(
            tokenizer_type=self.tokenizer_type,
            model_name=(
                self.tiktoken_model_name
                if self.tokenizer_type == "tiktoken"
                else self.huggingface_model_name
            ),
        )

        if self.using_azure_openai:
            # Swap the default OpenAI functions for their Azure OpenAI
            # counterparts, but only where the user kept the defaults.
            if self.best_model_func == gpt_4o_complete:
                self.best_model_func = azure_gpt_4o_complete
            if self.cheap_model_func == gpt_4o_mini_complete:
                self.cheap_model_func = azure_gpt_4o_mini_complete
            if self.embedding_func == openai_embedding:
                self.embedding_func = azure_openai_embedding
            logger.info(
                "Switched the default OpenAI functions to Azure OpenAI (user-provided overrides are kept)"
            )

        if self.using_amazon_bedrock:
            self.best_model_func = create_amazon_bedrock_complete_function(self.best_model_id)
            self.cheap_model_func = create_amazon_bedrock_complete_function(self.cheap_model_id)
            self.embedding_func = amazon_bedrock_embedding
            logger.info("Switched the default OpenAI functions to Amazon Bedrock")

        if not os.path.exists(self.working_dir) and self.always_create_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        # Initialize the KV, graph, and vector storage backends.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.community_reports = self.key_string_value_json_storage_cls(
            namespace="community_reports", global_config=asdict(self)
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        # Cap the number of concurrent embedding calls.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = (
            self.vector_db_storage_cls(
                namespace="entities",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
                meta_fields={"entity_name"},
            )
            if self.enable_local
            else None
        )
        self.chunks_vdb = (
            self.vector_db_storage_cls(
                namespace="chunks",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
            if self.enable_naive_rag
            else None
        )

        # Cap LLM concurrency and route responses through the cache.
        self.best_model_func = limit_async_func_call(self.best_model_max_async)(
            partial(self.best_model_func, hashing_kv=self.llm_response_cache)
        )
        self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
            partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode == "local" and not self.enable_local:
            raise ValueError("enable_local is False, cannot query in local mode")
        if param.mode == "naive" and not self.enable_naive_rag:
            raise ValueError("enable_naive_rag is False, cannot query in naive mode")
        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                self.tokenizer_wrapper,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response
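    # A minimal async usage sketch for callers that already run an event loop
    # (the synchronous insert()/query() wrappers above use run_until_complete
    # and so cannot be called from inside a running loop). Assumes valid
    # OpenAI credentials; mode="naive" additionally requires an instance
    # constructed with enable_naive_rag=True:
    #
    #     async def demo():
    #         rag = GraphRAG(enable_naive_rag=True)
    #         await rag.ainsert("Some long document text ...")
    #         print(await rag.aquery("Main themes?", QueryParam(mode="naive")))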
    async def ainsert(self, string_or_strings):
        await self._insert_start()
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]
            # ---------- new docs
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            inserting_chunks = get_chunks(
                new_docs=new_docs,
                chunk_func=self.chunk_func,
                overlap_token_size=self.chunk_overlap_token_size,
                max_token_size=self.chunk_token_size,
                tokenizer_wrapper=self.tokenizer_wrapper,
            )
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
            if self.enable_naive_rag:
                logger.info("Insert chunks for naive RAG")
                await self.chunks_vdb.upsert(inserting_chunks)

            # TODO: incremental update of communities is not supported yet,
            # so all existing community reports must be dropped and rebuilt.
            await self.community_reports.drop()

            # ---------- extract/summary entity and upsert to graph
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await self.entity_extraction_func(
                inserting_chunks,
                # NB: the misspelled keyword matches the parameter name
                # declared by extract_entities in _op.py.
                knwoledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                tokenizer_wrapper=self.tokenizer_wrapper,
                global_config=asdict(self),
                using_amazon_bedrock=self.using_amazon_bedrock,
            )
            if maybe_new_kg is None:
                logger.warning("No new entities found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            # ---------- update clusterings of graph
            logger.info("[Community Report]...")
            await self.chunk_entity_relation_graph.clustering(
                self.graph_cluster_algorithm
            )
            await generate_community_report(
                self.community_reports,
                self.chunk_entity_relation_graph,
                self.tokenizer_wrapper,
                asdict(self),
            )

            # ---------- commit upsertings and indexing
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()

    async def _insert_start(self):
        tasks = []
        for storage_inst in [
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_start_callback())
        await asyncio.gather(*tasks)

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.community_reports,
            self.entities_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
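# A minimal synchronous usage sketch, not part of the library API: it assumes
# OPENAI_API_KEY is set in the environment (so the default gpt_4o_complete /
# openai_embedding functions work) and makes live API calls when this module
# is executed directly. The inserted text is an inline placeholder document.
if __name__ == "__main__":
    rag = GraphRAG(working_dir="./nano_graphrag_demo_cache")
    rag.insert(
        "GraphRAG extracts entities from inserted text, clusters them into "
        "communities, and answers queries from community reports."
    )
    print(rag.query("What are the top themes?", param=QueryParam(mode="global")))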