init. project
This commit is contained in:
307
rag-web-ui/backend/nano_graphrag/_utils.py
Normal file
307
rag-web-ui/backend/nano_graphrag/_utils.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import asyncio
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import numbers
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
from typing import Any, Union, Literal
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
AutoTokenizer = None
|
||||
|
||||
logger = logging.getLogger("nano-graphrag")
|
||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
# If there is already an event loop, use it.
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# If in a sub-thread, create a new event loop.
|
||||
logger.info("Creating a new event loop in a sub-thread.")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
|
||||
def extract_first_complete_json(s: str):
|
||||
"""Extract the first complete JSON object from the string using a stack to track braces."""
|
||||
stack = []
|
||||
first_json_start = None
|
||||
|
||||
for i, char in enumerate(s):
|
||||
if char == '{':
|
||||
stack.append(i)
|
||||
if first_json_start is None:
|
||||
first_json_start = i
|
||||
elif char == '}':
|
||||
if stack:
|
||||
start = stack.pop()
|
||||
if not stack:
|
||||
first_json_str = s[first_json_start:i+1]
|
||||
try:
|
||||
# Attempt to parse the JSON string
|
||||
return json.loads(first_json_str.replace("\n", ""))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
|
||||
return None
|
||||
finally:
|
||||
first_json_start = None
|
||||
logger.warning("No complete JSON object found in the input string.")
|
||||
return None
|
||||
|
||||
def parse_value(value: str):
|
||||
"""Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
|
||||
value = value.strip()
|
||||
|
||||
if value == "null":
|
||||
return None
|
||||
elif value == "true":
|
||||
return True
|
||||
elif value == "false":
|
||||
return False
|
||||
else:
|
||||
# Try to convert to int or float
|
||||
try:
|
||||
if '.' in value: # If there's a dot, it might be a float
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
# If conversion fails, return the value as-is (likely a string)
|
||||
return value.strip('"') # Remove surrounding quotes if they exist
|
||||
|
||||
def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
|
||||
"""Extract key values from a non-standard or malformed JSON string, handling nested objects."""
|
||||
extracted_values = {}
|
||||
|
||||
# Enhanced pattern to match both quoted and unquoted values, as well as nested objects
|
||||
regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'
|
||||
|
||||
for match in re.finditer(regex_pattern, json_string, re.DOTALL):
|
||||
key = match.group('key').strip('"') # Strip quotes from key
|
||||
value = match.group('value').strip()
|
||||
|
||||
# If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
|
||||
if value.startswith('{') and value.endswith('}'):
|
||||
extracted_values[key] = extract_values_from_json(value)
|
||||
else:
|
||||
# Parse the value into the appropriate type (int, float, bool, etc.)
|
||||
extracted_values[key] = parse_value(value)
|
||||
|
||||
if not extracted_values:
|
||||
logger.warning("No values could be extracted from the string.")
|
||||
|
||||
return extracted_values
|
||||
|
||||
|
||||
def convert_response_to_json(response: str) -> dict:
|
||||
"""Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
|
||||
prediction_json = extract_first_complete_json(response)
|
||||
|
||||
if prediction_json is None:
|
||||
logger.info("Attempting to extract values from a non-standard JSON string...")
|
||||
prediction_json = extract_values_from_json(response, allow_no_quotes=True)
|
||||
|
||||
if not prediction_json:
|
||||
logger.error("Unable to extract meaningful data from the response.")
|
||||
else:
|
||||
logger.info("JSON data successfully extracted.")
|
||||
|
||||
return prediction_json
|
||||
|
||||
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
def __init__(self, tokenizer_type: Literal["tiktoken", "huggingface"] = "tiktoken", model_name: str = "gpt-4o"):
|
||||
self.tokenizer_type = tokenizer_type
|
||||
self.model_name = model_name
|
||||
self._tokenizer = None
|
||||
self._lazy_load_tokenizer()
|
||||
|
||||
def _lazy_load_tokenizer(self):
|
||||
if self._tokenizer is not None:
|
||||
return
|
||||
logger.info(f"Loading tokenizer: type='{self.tokenizer_type}', name='{self.model_name}'")
|
||||
if self.tokenizer_type == "tiktoken":
|
||||
self._tokenizer = tiktoken.encoding_for_model(self.model_name)
|
||||
elif self.tokenizer_type == "huggingface":
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("`transformers` is not installed. Please install it via `pip install transformers` to use HuggingFace tokenizers.")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
|
||||
|
||||
def get_tokenizer(self):
|
||||
"""提供对底层 tokenizer 对象的访问,用于特殊情况(如 decode_batch)。"""
|
||||
self._lazy_load_tokenizer()
|
||||
return self._tokenizer
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
self._lazy_load_tokenizer()
|
||||
return self._tokenizer.encode(text)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
self._lazy_load_tokenizer()
|
||||
return self._tokenizer.decode(tokens)
|
||||
|
||||
# +++ 新增 +++: 增加一个批量解码的方法以提高效率,并保持接口一致性
|
||||
def decode_batch(self, tokens_list: list[list[int]]) -> list[str]:
|
||||
self._lazy_load_tokenizer()
|
||||
# HuggingFace tokenizer 有 decode_batch,但 tiktoken 没有,我们用列表推导来模拟
|
||||
if self.tokenizer_type == "tiktoken":
|
||||
return [self._tokenizer.decode(tokens) for tokens in tokens_list]
|
||||
elif self.tokenizer_type == "huggingface":
|
||||
return self._tokenizer.batch_decode(tokens_list, skip_special_tokens=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer_type: {self.tokenizer_type}")
|
||||
|
||||
|
||||
|
||||
def truncate_list_by_token_size(
|
||||
list_data: list,
|
||||
key: callable,
|
||||
max_token_size: int,
|
||||
tokenizer_wrapper: TokenizerWrapper
|
||||
):
|
||||
"""Truncate a list of data by token size using a provided tokenizer wrapper."""
|
||||
if max_token_size <= 0:
|
||||
return []
|
||||
tokens = 0
|
||||
for i, data in enumerate(list_data):
|
||||
tokens += len(tokenizer_wrapper.encode(key(data))) + 1 # 防御性,模拟通过\n拼接列表的情况
|
||||
if tokens > max_token_size:
|
||||
return list_data[:i]
|
||||
return list_data
|
||||
|
||||
|
||||
def compute_mdhash_id(content, prefix: str = ""):
|
||||
return prefix + md5(content.encode()).hexdigest()
|
||||
|
||||
|
||||
def write_json(json_obj, file_name):
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_json(file_name):
|
||||
if not os.path.exists(file_name):
|
||||
return None
|
||||
with open(file_name, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# it's dirty to type, so it's a good way to have fun
|
||||
def pack_user_ass_to_openai_messages(prompt: str, generated_content: str, using_amazon_bedrock: bool):
|
||||
if using_amazon_bedrock:
|
||||
return [
|
||||
{"role": "user", "content": [{"text": prompt}]},
|
||||
{"role": "assistant", "content": [{"text": generated_content}]},
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": generated_content},
|
||||
]
|
||||
|
||||
|
||||
def is_float_regex(value):
|
||||
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
||||
|
||||
|
||||
def compute_args_hash(*args):
|
||||
return md5(str(args).encode()).hexdigest()
|
||||
|
||||
|
||||
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
||||
"""Split a string by multiple markers"""
|
||||
if not markers:
|
||||
return [content]
|
||||
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
||||
return [r.strip() for r in results if r.strip()]
|
||||
|
||||
|
||||
def enclose_string_with_quotes(content: Any) -> str:
|
||||
"""Enclose a string with quotes"""
|
||||
if isinstance(content, numbers.Number):
|
||||
return str(content)
|
||||
content = str(content)
|
||||
content = content.strip().strip("'").strip('"')
|
||||
return f'"{content}"'
|
||||
|
||||
|
||||
def list_of_list_to_csv(data: list[list]):
|
||||
return "\n".join(
|
||||
[
|
||||
",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
|
||||
for data_d in data
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------------
|
||||
# Refer the utils functions of the official GraphRAG implementation:
|
||||
# https://github.com/microsoft/graphrag
|
||||
def clean_str(input: Any) -> str:
|
||||
"""Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
|
||||
# If we get non-string input, just give it back
|
||||
if not isinstance(input, str):
|
||||
return input
|
||||
|
||||
result = html.unescape(input.strip())
|
||||
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
||||
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
||||
|
||||
|
||||
# Utils types -----------------------------------------------------------------------
|
||||
@dataclass
|
||||
class EmbeddingFunc:
|
||||
embedding_dim: int
|
||||
max_token_size: int
|
||||
func: callable
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
# Decorators ------------------------------------------------------------------------
|
||||
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
||||
"""Add restriction of maximum async calling times for a async func"""
|
||||
|
||||
def final_decro(func):
|
||||
"""Not using async.Semaphore to aovid use nest-asyncio"""
|
||||
__current_size = 0
|
||||
|
||||
@wraps(func)
|
||||
async def wait_func(*args, **kwargs):
|
||||
nonlocal __current_size
|
||||
while __current_size >= max_size:
|
||||
await asyncio.sleep(waitting_time)
|
||||
__current_size += 1
|
||||
result = await func(*args, **kwargs)
|
||||
__current_size -= 1
|
||||
return result
|
||||
|
||||
return wait_func
|
||||
|
||||
return final_decro
|
||||
|
||||
|
||||
def wrap_embedding_func_with_attrs(**kwargs):
|
||||
"""Wrap a function with attributes"""
|
||||
|
||||
def final_decro(func) -> EmbeddingFunc:
|
||||
new_func = EmbeddingFunc(**kwargs, func=func)
|
||||
return new_func
|
||||
|
||||
return final_decro
|
||||
Reference in New Issue
Block a user