import asyncio
import html
import json
import logging
import os
import re
import numbers
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, Literal

import numpy as np

# tiktoken is only required when a "tiktoken" tokenizer is actually requested;
# guarding it (like transformers below) keeps the pure helpers importable without it.
try:
    import tiktoken
except ImportError:
    tiktoken = None

try:
    from transformers import AutoTokenizer
except ImportError:
    AutoTokenizer = None

logger = logging.getLogger("nano-graphrag")
logging.getLogger("neo4j").setLevel(logging.ERROR)


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return the current thread's event loop, creating and installing one if needed.

    If ``asyncio.get_event_loop()`` raises ``RuntimeError`` (no loop available for
    the calling thread), a new loop is created and set as this thread's loop.
    """
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Typically reached from a sub-thread that has no loop yet.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def extract_first_complete_json(s: str):
    """Extract and parse the first complete top-level JSON object in ``s``.

    Braces are matched with a depth counter starting at the first ``{``; as soon
    as that opening brace is balanced, the enclosed text is parsed. Braces inside
    JSON string values are not treated specially, so unbalanced braces inside
    strings can confuse the matcher.

    Returns:
        The parsed object, or ``None`` if no balanced ``{...}`` span exists or the
        first such span is not valid JSON.
    """
    depth = 0
    first_json_start = None
    for i, char in enumerate(s):
        if char == "{":
            if first_json_start is None:
                first_json_start = i
            depth += 1
        elif char == "}" and depth > 0:  # stray "}" before any "{" is ignored
            depth -= 1
            if depth == 0:
                first_json_str = s[first_json_start : i + 1]
                try:
                    # Raw newlines are stripped before parsing — presumably to tolerate
                    # LLM output with unescaped newlines inside string values.
                    return json.loads(first_json_str.replace("\n", ""))
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
                    return None
    logger.warning("No complete JSON object found in the input string.")
    return None


def parse_value(value: str):
    """Convert a loosely formatted scalar string to a Python value.

    ``null``/``true``/``false`` map to ``None``/``True``/``False``. Strings that
    parse as numbers become ``float`` (if they contain a ``.``) or ``int``.
    Anything else is returned with surrounding whitespace and double quotes
    stripped.
    """
    value = value.strip()
    if value == "null":
        return None
    if value == "true":
        return True
    if value == "false":
        return False
    try:
        # A dot suggests a float; otherwise try an integer.
        return float(value) if "." in value else int(value)
    except ValueError:
        # Not a number: treat it as a (possibly quoted) string.
        return value.strip('"')


def extract_values_from_json(json_string, keys=("reasoning", "answer", "data"), allow_no_quotes=False):
    """Best-effort extraction of ``key: value`` pairs from malformed JSON-like text.

    Keys may be quoted or bare word characters. A value may be any of:

    * a ``{...}`` object containing no ``}`` of its own, parsed recursively;
    * a double-quoted string;
    * any run of characters up to the next ``,`` or ``}``.

    Scalar values are converted with ``parse_value``.

    Note: ``keys`` and ``allow_no_quotes`` are accepted for API compatibility but
    are not used by this implementation; every matching pair is returned.

    Returns:
        A dict of extracted pairs (empty if nothing matched).
    """
    extracted_values = {}
    # The named groups "key" and "value" are read back via match.group() below.
    regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>\{[^}]*\}|".*?"|[^,}]+)'
    for match in re.finditer(regex_pattern, json_string, re.DOTALL):
        key = match.group("key").strip('"')  # Strip quotes from key
        value = match.group("value").strip()
        if value.startswith("{") and value.endswith("}"):
            # Nested object: parse it recursively.
            extracted_values[key] = extract_values_from_json(value)
        else:
            # Parse the value into the appropriate type (int, float, bool, etc.)
            extracted_values[key] = parse_value(value)
    if not extracted_values:
        logger.warning("No values could be extracted from the string.")
    return extracted_values


def convert_response_to_json(response: str) -> dict:
    """Parse a model response into a dict.

    Tries strict parsing of the first complete ``{...}`` object first. If that
    fails, falls back to regex-based extraction of ``key: value`` pairs. May
    return an empty dict when nothing can be extracted.
    """
    prediction_json = extract_first_complete_json(response)
    if prediction_json is None:
        logger.info("Attempting to extract values from a non-standard JSON string...")
        prediction_json = extract_values_from_json(response, allow_no_quotes=True)
    if not prediction_json:
        logger.error("Unable to extract meaningful data from the response.")
    else:
        logger.info("JSON data successfully extracted.")
    return prediction_json


class TokenizerWrapper:
    """Uniform encode/decode interface over a tiktoken or HuggingFace tokenizer."""

    def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
        self.tokenizer_type = tokenizer_type
        self.model_name = model_name
        self._tokenizer = None
        self._lazy_load_tokenizer()

    def _lazy_load_tokenizer(self):
        """Load the underlying tokenizer once; later calls are no-ops.

        Raises:
            ImportError: if the library for the requested tokenizer type is missing.
            ValueError: if ``tokenizer_type`` is not recognized.
        """
        if self._tokenizer is not None:
            return
        logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
        if self.tokenizer_type == "tiktoken":
            if tiktoken is None:
                raise ImportError("`tiktoken` is not installed. Please install it via `pip install tiktoken` to use tiktoken tokenizers.")
            self._tokenizer = tiktoken.encoding_for_model(self.model_name)
        elif self.tokenizer_type == "huggingface":
            if AutoTokenizer is None:
                raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")

    def get_tokenizer(self):
        """Expose the underlying tokenizer object for special cases not covered by this wrapper."""
        self._lazy_load_tokenizer()
        return self._tokenizer

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` into token ids."""
        self._lazy_load_tokenizer()
        return self._tokenizer.encode(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode a sequence of token ids back into text."""
        self._lazy_load_tokenizer()
        return self._tokenizer.decode(tokens)

    def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
        """Decode several token-id sequences, one string per input sequence.

        HuggingFace tokenizers use ``batch_decode`` (with special tokens skipped);
        for tiktoken each sequence is decoded individually.
        """
        self._lazy_load_tokenizer()
        if self.tokenizer_type == "tiktoken":
            return [self._tokenizer.decode(tokens) for tokens in tokens_list]
        elif self.tokenizer_type == "huggingface":
            return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
        else:
            raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")


def truncate_list_by_token_size(
    list_data: list,
    key: callable,
    max_token_size: int,
    tokenizer_wrapper: TokenizerWrapper,
):
    """Return the longest prefix of ``list_data`` whose token count fits ``max_token_size``.

    Each item's size is ``len(encode(key(item))) + 1``.
    """
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        # +1 per item is defensive: it accounts for the "\n" used when the items
        # are later joined into a single string.
        tokens += len(tokenizer_wrapper.encode(key(data))) + 1
        if tokens > max_token_size:
            return list_data[:i]
    return list_data


def compute_mdhash_id(content, prefix: str = ""):
    """Return ``prefix`` followed by the hex MD5 digest of the string ``content``."""
    return prefix + md5(content.encode()).hexdigest()


def write_json(json_obj, file_name):
    """Write ``json_obj`` to ``file_name`` as UTF-8, indented JSON (non-ASCII kept as-is)."""
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


def load_json(file_name):
    """Load JSON from ``file_name``; return ``None`` if the file does not exist."""
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
    """Build a two-message chat history: the user prompt and the assistant reply.

    Amazon Bedrock expects ``content`` as a list of ``{"text": ...}`` parts;
    otherwise plain strings are used.
    """
    if using_amazon_bedrock:
        return [
            {"role": "user", "content": [{"text": prompt}]},
            {"role": "assistant", "content": [{"text": generated_content}]},
        ]
    else:
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": generated_content},
        ]


def is_float_regex(value):
    """Return True if ``value`` is a plain signed decimal number (no exponent)."""
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def compute_args_hash(*args):
    """Return the hex MD5 digest of ``str(args)`` (used as a cache key)."""
    return md5(str(args).encode()).hexdigest()


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split ``content`` on any of the literal ``markers``.

    Drops empty or whitespace-only pieces and strips the rest.
    """
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


def enclose_string_with_quotes(content: Any) -> str:
    """Render ``content`` for CSV output.

    Numbers are returned bare. Anything else is stripped of surrounding
    whitespace and quotes, then wrapped in double quotes.
    """
    if isinstance(content, numbers.Number):
        return str(content)
    content = str(content)
    content = content.strip().strip("'").strip('"')
    return f'"{content}"'


def list_of_list_to_csv(data: list[list]):
    """Join rows with newlines and cells with ``",\\t"``, quoting cells via ``enclose_string_with_quotes``."""
    return "\n".join(
        [
            ",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
            for data_d in data
        ]
    )


# -----------------------------------------------------------------------------------
# Adapted from the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


# Utils types -----------------------------------------------------------------------
@dataclass
class EmbeddingFunc:
    """An async embedding function bundled with its declared metadata."""

    embedding_dim: int  # declared vector dimensionality; not checked here
    max_token_size: int  # declared input token limit; presumably enforced by callers
    func: callable  # the wrapped coroutine function

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)


# Decorators ------------------------------------------------------------------------
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Limit how many calls of the decorated async function may run concurrently.

    Callers beyond ``max_size`` poll every ``waitting_time`` seconds until a slot
    frees up. A plain counter is used instead of ``asyncio.Semaphore`` (per the
    original design note, to avoid needing nest-asyncio).
    """

    def final_decro(func):
        current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal current_size
            while current_size >= max_size:
                await asyncio.sleep(waitting_time)
            # No await between the check above and this increment, so the
            # check-and-claim cannot interleave with other coroutines.
            current_size += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # Always release the slot, even if func raises or is cancelled;
                # otherwise the slot leaks and later callers wait forever.
                current_size -= 1

        return wait_func

    return final_decro


def wrap_embedding_func_with_attrs(**kwargs):
    """Decorator that wraps an async function in an ``EmbeddingFunc`` with the given attributes."""

    def final_decro(func) -> EmbeddingFunc:
        new_func = EmbeddingFunc(**kwargs, func=func)
        return new_func

    return final_decro