Files

1141 lines
39 KiB
Python
Raw Permalink Normal View History

2026-04-13 11:34:23 +08:00
import re
import json
import asyncio
from typing import Union
from collections import Counter, defaultdict
from ._splitter import SeparatorSplitter
from ._utils import (
logger,
clean_str,
compute_mdhash_id,
is_float_regex,
list_of_list_to_csv,
pack_user_ass_to_openai_messages,
split_string_by_multi_markers,
truncate_list_by_token_size,
TokenizerWrapper
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
SingleCommunitySchema,
CommunitySchema,
TextChunkSchema,
QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
def chunking_by_token_size(
    tokens_list: list[list[int]],
    doc_keys,
    tokenizer_wrapper: TokenizerWrapper,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Split tokenized documents into fixed-size, overlapping token windows.

    Args:
        tokens_list: One token-id list per document.
        doc_keys: Document identifiers, positionally aligned with ``tokens_list``.
        tokenizer_wrapper: Object whose ``decode_batch`` turns a list of
            token-id lists back into a list of strings.
        overlap_token_size: Number of tokens shared by consecutive windows.
        max_token_size: Maximum number of tokens per window.

    Returns:
        A flat list of dicts with keys ``tokens`` (window length),
        ``content`` (decoded, stripped text), ``chunk_order_index`` (position
        of the window inside its document) and ``full_doc_id``.

    Raises:
        ValueError: If ``overlap_token_size`` is not smaller than
            ``max_token_size``; the window could then never advance.
    """
    stride = max_token_size - overlap_token_size
    if stride <= 0:
        # A negative step would make range() yield nothing and silently drop
        # every document; fail loudly instead.
        raise ValueError(
            "overlap_token_size must be smaller than max_token_size "
            f"(got overlap_token_size={overlap_token_size}, "
            f"max_token_size={max_token_size})"
        )

    results = []
    for index, tokens in enumerate(tokens_list):
        windows = [
            tokens[start : start + max_token_size]
            for start in range(0, len(tokens), stride)
        ]
        chunk_texts = tokenizer_wrapper.decode_batch(windows)
        for i, (window, chunk) in enumerate(zip(windows, chunk_texts)):
            results.append(
                {
                    "tokens": len(window),
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )
    return results
def chunking_by_seperators(
    tokens_list: list[list[int]],
    doc_keys,
    tokenizer_wrapper: TokenizerWrapper,
    overlap_token_size=128,
    max_token_size=1024,
):
    """Chunk tokenized documents with a ``SeparatorSplitter``.

    The separators listed in ``PROMPTS["default_text_separator"]`` are encoded
    to token ids and handed to the splitter together with the chunk size and
    overlap. Every produced token chunk is decoded back to text.

    Returns:
        A flat list of dicts with keys ``tokens``, ``content``,
        ``chunk_order_index`` and ``full_doc_id``.
    """
    from .prompt import PROMPTS

    # Encode through the wrapper directly instead of fetching the underlying tokenizer.
    encoded_separators = [
        tokenizer_wrapper.encode(sep) for sep in PROMPTS["default_text_separator"]
    ]
    token_splitter = SeparatorSplitter(
        separators=encoded_separators,
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )

    all_chunks = []
    for doc_pos, doc_tokens in enumerate(tokens_list):
        pieces = token_splitter.split_tokens(doc_tokens)
        piece_sizes = [len(piece) for piece in pieces]
        texts = tokenizer_wrapper.decode_batch(pieces)
        for order, text in enumerate(texts):
            record = {
                "tokens": piece_sizes[order],
                "content": text.strip(),
                "chunk_order_index": order,
                "full_doc_id": doc_keys[doc_pos],
            }
            all_chunks.append(record)
    return all_chunks
def get_chunks(new_docs, chunk_func=chunking_by_token_size, tokenizer_wrapper: TokenizerWrapper = None, **chunk_func_params):
    """Tokenize documents, chunk them, and key each chunk by a content hash.

    Args:
        new_docs: Mapping of document key to a dict holding a ``"content"`` string.
        chunk_func: Chunking callable; receives the token lists plus
            ``doc_keys``, ``tokenizer_wrapper``, ``overlap_token_size`` and
            ``max_token_size`` keyword arguments.
        tokenizer_wrapper: Object providing ``encode`` for the documents.
        **chunk_func_params: Only ``overlap_token_size`` (default 128) and
            ``max_token_size`` (default 1024) are read and forwarded.

    Returns:
        Dict mapping ``compute_mdhash_id(content, prefix="chunk-")`` to the
        chunk dict; chunks with identical content share one entry.
    """
    doc_items = list(new_docs.items())
    keys = [key for key, _ in doc_items]
    encoded_docs = [tokenizer_wrapper.encode(doc["content"]) for _, doc in doc_items]

    produced = chunk_func(
        encoded_docs,
        doc_keys=keys,
        tokenizer_wrapper=tokenizer_wrapper,
        overlap_token_size=chunk_func_params.get("overlap_token_size", 128),
        max_token_size=chunk_func_params.get("max_token_size", 1024),
    )

    # A later chunk with the same content hash replaces the earlier one.
    return {
        compute_mdhash_id(piece["content"], prefix="chunk-"): piece
        for piece in produced
    }
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
    tokenizer_wrapper: TokenizerWrapper,
) -> str:
    """Summarize a merged description with the cheap LLM when it is too long.

    The description is returned unchanged when its token count is below
    ``global_config["entity_summary_to_max_tokens"]``. Otherwise it is cut to
    ``global_config["cheap_model_max_token_size"]`` tokens, split on
    ``GRAPH_FIELD_SEP`` and passed to ``global_config["cheap_model_func"]``
    via the ``summarize_entity_descriptions`` prompt.
    """
    cheap_llm: callable = global_config["cheap_model_func"]
    context_limit = global_config["cheap_model_max_token_size"]
    summary_limit = global_config["entity_summary_to_max_tokens"]

    description_tokens = tokenizer_wrapper.encode(description)
    if len(description_tokens) < summary_limit:
        # Short enough already: no LLM call needed.
        return description

    template = PROMPTS["summarize_entity_descriptions"]
    clipped_text = tokenizer_wrapper.decode(description_tokens[:context_limit])
    prompt = template.format(
        entity_name=entity_or_relation_name,
        description_list=clipped_text.split(GRAPH_FIELD_SEP),
    )
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    return await cheap_llm(prompt, max_tokens=summary_limit)
async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Parse one split record into an entity dict.

    Returns ``None`` when the record has fewer than four fields, when its
    first field is not ``'"entity"'``, or when the cleaned entity name is
    blank. Otherwise returns a dict with ``entity_name``, ``entity_type``,
    ``description`` and ``source_id`` (the originating chunk key).
    """
    is_entity_record = (
        len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
    )
    if not is_entity_record:
        return None

    _, raw_name, raw_type, raw_description = record_attributes[:4]

    # Name and type are upper-cased before cleaning.
    name = clean_str(raw_name.upper())
    if not name.strip():
        return None

    return {
        "entity_name": name,
        "entity_type": clean_str(raw_type.upper()),
        "description": clean_str(raw_description),
        "source_id": chunk_key,
    }
async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Parse one split record into a relationship (edge) dict.

    Returns ``None`` when the record has fewer than five fields or its first
    field is not ``'"relationship"'``. Otherwise returns a dict with
    ``src_id``, ``tgt_id``, ``weight``, ``description`` and ``source_id``.
    The weight comes from the last field when ``is_float_regex`` accepts it,
    else it is ``1.0``.
    """
    is_relationship_record = (
        len(record_attributes) >= 5 and record_attributes[0] == '"relationship"'
    )
    if not is_relationship_record:
        return None

    src = clean_str(record_attributes[1].upper())
    tgt = clean_str(record_attributes[2].upper())
    description = clean_str(record_attributes[3])

    raw_weight = record_attributes[-1]
    if is_float_regex(raw_weight):
        weight = float(raw_weight)
    else:
        weight = 1.0

    return {
        "src_id": src,
        "tgt_id": tgt,
        "weight": weight,
        "description": description,
        "source_id": chunk_key,
    }
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    tokenizer_wrapper,
):
    """Merge newly extracted records for one entity with its stored node.

    The stored node (if any) contributes its type, description and source ids.
    The merged node uses the most frequent entity type, the sorted unique
    descriptions joined by ``GRAPH_FIELD_SEP`` (summarized when too long), and
    the sorted unique source ids joined by ``GRAPH_FIELD_SEP``.

    Returns:
        The upserted node data with an added ``"entity_name"`` key.
    """
    existing_types = []
    existing_source_ids = []
    existing_descriptions = []

    existing_node = await knwoledge_graph_inst.get_node(entity_name)
    if existing_node is not None:
        existing_types.append(existing_node["entity_type"])
        existing_source_ids.extend(
            split_string_by_multi_markers(existing_node["source_id"], [GRAPH_FIELD_SEP])
        )
        existing_descriptions.append(existing_node["description"])

    # Most frequent type wins; ties go to the type encountered first.
    type_counts = Counter(
        [dp["entity_type"] for dp in nodes_data] + existing_types
    )
    entity_type = type_counts.most_common(1)[0][0]

    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + existing_descriptions))
    )
    # Sorted so the stored value is reproducible; iterating a bare set of
    # strings can give a different order on every interpreter run.
    source_id = GRAPH_FIELD_SEP.join(
        sorted(set([dp["source_id"] for dp in nodes_data] + existing_source_ids))
    )
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config, tokenizer_wrapper
    )
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    await knwoledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data
async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
    tokenizer_wrapper,
):
    """Merge newly extracted records for one edge with its stored edge.

    Weights are summed, the smallest ``order`` is kept, and descriptions and
    source ids are de-duplicated, sorted and joined by ``GRAPH_FIELD_SEP``.
    Endpoints that are not yet in the graph are created as ``"UNKNOWN"``
    nodes carrying the edge's description and source ids. The description is
    summarized when too long before the edge is upserted.
    """
    existing_weights = []
    existing_source_ids = []
    existing_descriptions = []
    existing_orders = []

    if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
        existing_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
        existing_weights.append(existing_edge["weight"])
        existing_source_ids.extend(
            split_string_by_multi_markers(existing_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        existing_descriptions.append(existing_edge["description"])
        existing_orders.append(existing_edge.get("order", 1))

    # [numberchiffre]: `Relationship.order` is only returned from DSPy's predictions
    order = min([dp.get("order", 1) for dp in edges_data] + existing_orders)
    weight = sum([dp["weight"] for dp in edges_data] + existing_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + existing_descriptions))
    )
    # Sorted so the stored value is reproducible; iterating a bare set of
    # strings can give a different order on every interpreter run.
    source_id = GRAPH_FIELD_SEP.join(
        sorted(set([dp["source_id"] for dp in edges_data] + existing_source_ids))
    )

    # Create placeholder nodes for endpoints the graph does not know yet.
    # They receive the un-summarized edge description.
    for need_insert_id in [src_id, tgt_id]:
        if not (await knwoledge_graph_inst.has_node(need_insert_id)):
            await knwoledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )

    description = await _handle_entity_relation_summary(
        (src_id, tgt_id), description, global_config, tokenizer_wrapper
    )
    await knwoledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight, description=description, source_id=source_id, order=order
        ),
    )
async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    tokenizer_wrapper,
    global_config: dict,
    using_amazon_bedrock: bool=False,
) -> Union[BaseGraphStorage, None]:
    """Extract entities and relationships from text chunks and store them.

    For every chunk the extraction prompt is sent to
    ``global_config["best_model_func"]``, followed by up to
    ``global_config["entity_extract_max_gleaning"]`` "continue" rounds. The
    combined output is split into records, parsed into entity/relationship
    dicts, merged across chunks and upserted into ``knwoledge_graph_inst``.
    Entities are also upserted into ``entity_vdb`` when it is not ``None``.

    Args:
        chunks: Mapping of chunk key to chunk dict (``"content"`` is read).
        knwoledge_graph_inst: Graph storage receiving nodes and edges.
        entity_vdb: Optional vector storage receiving entity records.
        tokenizer_wrapper: Passed through to the merge helpers.
        global_config: Provides the LLM callable and gleaning limit.
        using_amazon_bedrock: Forwarded to ``pack_user_ass_to_openai_messages``.

    Returns:
        ``knwoledge_graph_inst``, or ``None`` when no entity was extracted.
    """
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]

    already_processed = 0
    already_entities = 0
    already_relations = 0

    def _response_text(response):
        # A list response is reduced to the text of its first item. This is
        # applied to every LLM call so gleaning and "continue?" replies are
        # handled the same way as the initial extraction reply.
        if isinstance(response, list):
            return response[0]["text"]
        return response

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = _response_text(await use_llm_func(hint_prompt))

        history = pack_user_ass_to_openai_messages(hint_prompt, final_result, using_amazon_bedrock)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = _response_text(
                await use_llm_func(continue_prompt, history_messages=history)
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result, using_amazon_bedrock)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            # Ask the model whether another gleaning round is worthwhile.
            if_loop_result: str = _response_text(
                await use_llm_func(if_loop_prompt, history_messages=history)
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            # Keep only what sits between the outermost parentheses.
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    # use_llm_func is wrapped in asyncio.Semaphore, limiting max_async callings
    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # clear the progress bar

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            # it's undirected graph
            maybe_edges[tuple(sorted(k))].extend(v)

    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config, tokenizer_wrapper)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    return knwoledge_graph_inst
def _pack_single_community_by_sub_communities(
    community: SingleCommunitySchema,
    max_token_size: int,
    already_reports: dict[str, CommunitySchema],
    tokenizer_wrapper: TokenizerWrapper,
) -> tuple[str, int, set, set]:
    """Describe a community through the reports of its sub-communities.

    Sub-communities that already have a report are ordered by descending
    ``occurrence``, passed through ``truncate_list_by_token_size`` with
    ``max_token_size``, and rendered as CSV.

    Returns:
        ``(csv_text, token_count_of_csv_text, node_set, edge_set)`` where the
        two sets collect the nodes and edges of the retained sub-communities.
    """
    known_reports = [
        already_reports[key]
        for key in community["sub_communities"]
        if key in already_reports
    ]
    known_reports.sort(key=lambda report: report["occurrence"], reverse=True)

    kept_reports = truncate_list_by_token_size(
        known_reports,
        key=lambda report: report["report_string"],
        max_token_size=max_token_size,
        tokenizer_wrapper=tokenizer_wrapper,
    )

    csv_rows = [["id", "report", "rating", "importance"]]
    for row_id, report in enumerate(kept_reports):
        csv_rows.append(
            [
                row_id,
                report["report_string"],
                report["report_json"].get("rating", -1),
                report["occurrence"],
            ]
        )
    described = list_of_list_to_csv(csv_rows)

    covered_nodes = {node for report in kept_reports for node in report["nodes"]}
    covered_edges = {tuple(edge) for report in kept_reports for edge in report["edges"]}

    described_size = len(tokenizer_wrapper.encode(described))
    return described, described_size, covered_nodes, covered_edges
async def _pack_single_community_describe(
    knwoledge_graph_inst: BaseGraphStorage,
    community: SingleCommunitySchema,
    tokenizer_wrapper: "TokenizerWrapper",
    max_token_size: int = 12000,
    already_reports: Union[dict[str, CommunitySchema], None] = None,
    global_config: Union[dict, None] = None,
) -> str:
    """Render a community as a token-budgeted Reports/Entities/Relationships text.

    Args:
        knwoledge_graph_inst: Graph storage used to load nodes, edges and degrees.
        community: Community record; ``nodes``, ``edges``, ``sub_communities``
            and ``title`` are read.
        tokenizer_wrapper: Object providing ``encode`` for token counting.
        max_token_size: Total token budget for the returned text.
        already_reports: Reports of already processed communities, keyed by
            community key. ``None`` is treated as empty.
        global_config: Configuration dict; only
            ``addon_params["force_to_use_sub_communities"]`` is read. ``None``
            or a dict without ``addon_params`` means "not forced".

    Returns:
        The filled-in template string.
    """
    # Normalise the optional arguments; this avoids shared mutable defaults
    # and the KeyError an empty config used to raise on "addon_params".
    already_reports = {} if already_reports is None else already_reports
    global_config = {} if global_config is None else global_config

    # 1. Load the raw node and edge records in a deterministic order.
    nodes_in_order = sorted(community["nodes"])
    edges_in_order = sorted(community["edges"], key=lambda x: x[0] + x[1])

    nodes_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_node(n) for n in nodes_in_order]
    )
    edges_data = await asyncio.gather(
        *[knwoledge_graph_inst.get_edge(src, tgt) for src, tgt in edges_in_order]
    )

    # 2. Define the template and measure its fixed token overhead.
    final_template = """-----Reports-----
```csv
{reports}
```
-----Entities-----
```csv
{entities}
```
-----Relationships-----
```csv
{relationships}
```"""
    base_template_tokens = len(tokenizer_wrapper.encode(
        final_template.format(reports="", entities="", relationships="")
    ))
    remaining_budget = max_token_size - base_template_tokens

    # 3. Optionally describe the community through its sub-community reports.
    report_describe = ""
    contain_nodes = set()
    contain_edges = set()

    # Heuristic: a community with more than 100 nodes or edges is assumed to
    # need truncation.
    truncated = len(nodes_in_order) > 100 or len(edges_in_order) > 100

    need_to_use_sub_communities = (
        truncated and
        community["sub_communities"] and
        already_reports
    )
    force_to_use_sub_communities = global_config.get("addon_params", {}).get(
        "force_to_use_sub_communities", False
    )

    if need_to_use_sub_communities or force_to_use_sub_communities:
        logger.debug(f"Community {community['title']} using sub-communities")
        # Sub-community reports plus the nodes/edges they already cover.
        result = _pack_single_community_by_sub_communities(
            community, remaining_budget, already_reports, tokenizer_wrapper
        )
        report_describe, report_size, contain_nodes, contain_edges = result
        remaining_budget = max(0, remaining_budget - report_size)

    # 4. Prepare node and edge rows, skipping what sub-communities cover.
    def format_row(row: list) -> str:
        return ','.join('"{}"'.format(str(item).replace('"', '""')) for item in row)

    node_fields = ["id", "entity", "type", "description", "degree"]
    edge_fields = ["id", "source", "target", "description", "rank"]

    # Degrees are used as the importance of each row.
    node_degrees = await knwoledge_graph_inst.node_degrees_batch(nodes_in_order)
    edge_degrees = await knwoledge_graph_inst.edge_degrees_batch(edges_in_order)

    nodes_list_data = [
        [i, name, data.get("entity_type", "UNKNOWN"),
         data.get("description", "UNKNOWN"), node_degrees[i]]
        for i, (name, data) in enumerate(zip(nodes_in_order, nodes_data))
        if name not in contain_nodes  # already covered by a sub-community report
    ]
    edges_list_data = [
        [i, edge[0], edge[1], data.get("description", "UNKNOWN"), edge_degrees[i]]
        for i, (edge, data) in enumerate(zip(edges_in_order, edges_data))
        if (edge[0], edge[1]) not in contain_edges  # already covered by a sub-community report
    ]

    # Most important rows first, so truncation drops the least important.
    nodes_list_data.sort(key=lambda x: x[-1], reverse=True)
    edges_list_data.sort(key=lambda x: x[-1], reverse=True)

    # 5. Split the remaining budget between nodes and edges.
    # Account for the two CSV header rows first.
    header_tokens = len(tokenizer_wrapper.encode(
        list_of_list_to_csv([node_fields]) + "\n" + list_of_list_to_csv([edge_fields])
    ))
    data_budget = max(0, remaining_budget - header_tokens)
    total_items = len(nodes_list_data) + len(edges_list_data)
    # Budget is shared proportionally to the number of rows of each kind.
    node_ratio = len(nodes_list_data) / max(1, total_items)
    edge_ratio = 1 - node_ratio

    nodes_final = truncate_list_by_token_size(
        nodes_list_data, key=format_row,
        max_token_size=int(data_budget * node_ratio),
        tokenizer_wrapper=tokenizer_wrapper
    )
    edges_final = truncate_list_by_token_size(
        edges_list_data, key=format_row,
        max_token_size=int(data_budget * edge_ratio),
        tokenizer_wrapper=tokenizer_wrapper
    )

    # 6. Assemble the final output.
    nodes_describe = list_of_list_to_csv([node_fields] + nodes_final)
    edges_describe = list_of_list_to_csv([edge_fields] + edges_final)

    final_output = final_template.format(
        reports=report_describe,
        entities=nodes_describe,
        relationships=edges_describe
    )
    return final_output
def _community_report_json_to_str(parsed_output: dict) -> str:
    """Render a parsed community report as Markdown-style text.

    Refers to the official graphrag: index/graph/extractors/community_reports.
    The title defaults to ``"Report"`` and the summary to ``""``. Each
    finding becomes a ``##`` section; a plain-string finding is used as the
    heading with an empty body.
    """

    def render_finding(finding) -> str:
        if isinstance(finding, str):
            heading, body = finding, ""
        else:
            heading, body = finding.get("summary"), finding.get("explanation")
        return f"## {heading}\n\n{body}"

    title = parsed_output.get("title", "Report")
    summary = parsed_output.get("summary", "")
    sections = [render_finding(item) for item in parsed_output.get("findings", [])]
    report_sections = "\n\n".join(sections)
    return f"# {title}\n\n{summary}\n\n{report_sections}"
async def generate_community_report(
    community_report_kv: BaseKVStorage[CommunitySchema],
    knwoledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper: TokenizerWrapper,
    global_config: dict,
):
    """Generate an LLM report for every community and upsert them all.

    Levels are processed from the highest value to the lowest. Communities of
    one level are reported concurrently, and each can see the reports already
    produced for previously processed levels. The result is upserted into
    ``community_report_kv`` in a single call.
    """
    llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
    use_llm_func: callable = global_config["best_model_func"]
    use_string_json_convert_func: callable = global_config["convert_response_to_json_func"]

    communities_schema = await knwoledge_graph_inst.community_schema()
    schema_items = list(zip(communities_schema.keys(), communities_schema.values()))
    already_processed = 0

    prompt_template = PROMPTS["community_report"]
    prompt_overhead = len(tokenizer_wrapper.encode(prompt_template.format(input_text="")))

    async def _form_single_community_report(
        community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
    ):
        nonlocal already_processed
        # Reserve 200 extra tokens for the chat template and prompt template.
        budget = global_config["best_model_max_token_size"] - prompt_overhead - 200
        describe = await _pack_single_community_describe(
            knwoledge_graph_inst,
            community,
            tokenizer_wrapper=tokenizer_wrapper,
            max_token_size=budget,
            already_reports=already_reports,
            global_config=global_config,
        )
        response = await use_llm_func(
            prompt_template.format(input_text=describe), **llm_extra_kwargs
        )
        data = use_string_json_convert_func(response)

        already_processed += 1
        tickers = PROMPTS["process_tickers"]
        now_ticks = tickers[already_processed % len(PROMPTS["process_tickers"])]
        print(f"{now_ticks} Processed {already_processed} communities\r", end="", flush=True)
        return data

    levels = sorted({value["level"] for _, value in schema_items}, reverse=True)
    logger.info(f"Generating by levels: {levels}")

    community_datas = {}
    for level in levels:
        level_items = [(key, value) for key, value in schema_items if value["level"] == level]
        level_reports = await asyncio.gather(
            *[
                _form_single_community_report(value, community_datas)
                for _, value in level_items
            ]
        )
        # Build the whole level first, then merge it in one step.
        level_entries = {
            key: {
                "report_string": _community_report_json_to_str(report),
                "report_json": report,
                **value,
            }
            for (key, value), report in zip(level_items, level_reports)
        }
        community_datas.update(level_entries)

    print()  # clear the progress bar
    await community_report_kv.upsert(community_datas)
async def _find_most_related_community_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    community_reports: BaseKVStorage[CommunitySchema],
    tokenizer_wrapper,
):
    """Pick the community reports most related to the given entities.

    Each node's ``"clusters"`` JSON is decoded and clusters whose level is at
    most ``query_param.level`` are counted. Reports are fetched by cluster key
    and ranked by (count, report rating) in descending order, then passed
    through ``truncate_list_by_token_size`` with
    ``query_param.local_max_token_for_community_report``. Clusters without a
    stored report are skipped.

    Returns:
        The retained community reports; at most one when
        ``query_param.local_community_single_one`` is set.
    """
    related_communities = []
    for node_d in node_datas:
        if "clusters" not in node_d:
            continue
        related_communities.extend(json.loads(node_d["clusters"]))

    related_community_dup_keys = [
        str(dp["cluster"])
        for dp in related_communities
        if dp["level"] <= query_param.level
    ]
    related_community_keys_counts = dict(Counter(related_community_dup_keys))

    _related_community_datas = await asyncio.gather(
        *[community_reports.get_by_id(k) for k in related_community_keys_counts.keys()]
    )
    related_community_datas = {
        k: v
        for k, v in zip(related_community_keys_counts.keys(), _related_community_datas)
        if v is not None
    }

    # Rank only the keys that actually have a report; ranking every counted
    # key would raise KeyError for a cluster whose report is missing.
    related_community_keys = sorted(
        related_community_datas,
        key=lambda k: (
            related_community_keys_counts[k],
            related_community_datas[k]["report_json"].get("rating", -1),
        ),
        reverse=True,
    )
    sorted_community_datas = [
        related_community_datas[k] for k in related_community_keys
    ]

    use_community_reports = truncate_list_by_token_size(
        sorted_community_datas,
        key=lambda x: x["report_string"],
        max_token_size=query_param.local_max_token_for_community_report,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    if query_param.local_community_single_one:
        use_community_reports = use_community_reports[:1]
    return use_community_reports
async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper,
):
    """Collect the text chunks behind the given entities, best ones first.

    A chunk is taken from each entity's ``source_id``. It is ranked first by
    the position of the first entity referencing it, then by how many of that
    entity's edges lead to a neighbour whose ``source_id`` also contains the
    chunk. Chunks missing from ``text_chunks_db`` are logged and skipped. The
    ranked list is passed through ``truncate_list_by_token_size`` with
    ``query_param.local_max_token_for_text_unit``.

    Returns:
        The retained chunk records as returned by ``text_chunks_db``.
    """
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await knowledge_graph_inst.get_nodes_edges_batch([dp["entity_name"] for dp in node_datas])

    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])
    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)

    # Neighbour name -> set of chunk ids that neighbour was extracted from.
    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None
    }

    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        # Same empty/None tolerance as the neighbour collection loop above.
        neighbour_ids = [e[1] for e in (this_edges or [])]
        for c_id in this_text_units:
            if c_id in all_text_units_lookup:
                continue
            relation_counts = sum(
                1
                for neighbour in neighbour_ids
                if c_id in all_one_hop_text_units_lookup.get(neighbour, ())
            )
            all_text_units_lookup[c_id] = {
                "data": await text_chunks_db.get_by_id(c_id),
                "order": index,
                "relation_counts": relation_counts,
            }

    # The lookup entry is always a dict, so the fetched "data" is what has to
    # be checked for a missing chunk.
    if any(v["data"] is None for v in all_text_units_lookup.values()):
        logger.warning("Text chunks are missing, maybe the storage is damaged")
    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v["data"] is not None
    ]
    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.local_max_token_for_text_unit,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
    return all_text_units
async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
    tokenizer_wrapper,
):
    """Collect the unique edges touching the given entities, best ones first.

    Edges are de-duplicated after sorting their endpoints, enriched with
    their stored data and degree (as ``"rank"``), and ordered by
    (rank, weight) descending. The list is then passed through
    ``truncate_list_by_token_size`` with
    ``query_param.local_max_token_for_local_context``.
    """
    entity_names = [dp["entity_name"] for dp in node_datas]
    edges_per_node = await knowledge_graph_inst.get_nodes_edges_batch(entity_names)

    unique_edges = []
    visited = set()
    for node_edges in edges_per_node:
        for raw_edge in node_edges:
            # Endpoint order is normalised so (a, b) and (b, a) count once.
            canonical = tuple(sorted(raw_edge))
            if canonical in visited:
                continue
            visited.add(canonical)
            unique_edges.append(canonical)

    edge_payloads = await knowledge_graph_inst.get_edges_batch(unique_edges)
    edge_ranks = await knowledge_graph_inst.edge_degrees_batch(unique_edges)

    ranked_edges = [
        {"src_tgt": pair, "rank": degree, **payload}
        for pair, payload, degree in zip(unique_edges, edge_payloads, edge_ranks)
        if payload is not None
    ]
    ranked_edges.sort(key=lambda item: (item["rank"], item["weight"]), reverse=True)

    return truncate_list_by_token_size(
        ranked_edges,
        key=lambda item: item["description"],
        max_token_size=query_param.local_max_token_for_local_context,
        tokenizer_wrapper=tokenizer_wrapper,
    )
async def _build_local_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
):
    """Build the local-search context text for a query.

    The top ``query_param.top_k`` entities are retrieved from
    ``entities_vdb``. Their related community reports, relationships and
    source text units are gathered and rendered as four CSV sections.

    Returns:
        The context string, or ``None`` when the entity search has no hits.
    """
    hits = await entities_vdb.query(query, top_k=query_param.top_k)
    if len(hits) == 0:
        return None

    stored_nodes = await knowledge_graph_inst.get_nodes_batch([r["entity_name"] for r in hits])
    if any(node is None for node in stored_nodes):
        logger.warning("Some nodes are missing, maybe the storage is damaged")
    degrees = await knowledge_graph_inst.node_degrees_batch([r["entity_name"] for r in hits])

    # Keep only hits that still exist in the graph; attach name and degree.
    node_datas = []
    for hit, node, degree in zip(hits, stored_nodes, degrees):
        if node is None:
            continue
        node_datas.append({**node, "entity_name": hit["entity_name"], "rank": degree})

    use_communities = await _find_most_related_community_from_entities(
        node_datas, query_param, community_reports, tokenizer_wrapper
    )
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst, tokenizer_wrapper
    )
    use_relations = await _find_most_related_edges_from_entities(
        node_datas, query_param, knowledge_graph_inst, tokenizer_wrapper
    )
    logger.info(
        f"Using {len(node_datas)} entites, {len(use_communities)} communities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )

    entity_rows = [["id", "entity", "type", "description", "rank"]] + [
        [
            row_id,
            node["entity_name"],
            node.get("entity_type", "UNKNOWN"),
            node.get("description", "UNKNOWN"),
            node["rank"],
        ]
        for row_id, node in enumerate(node_datas)
    ]
    entities_context = list_of_list_to_csv(entity_rows)

    relation_rows = [["id", "source", "target", "description", "weight", "rank"]] + [
        [
            row_id,
            edge["src_tgt"][0],
            edge["src_tgt"][1],
            edge["description"],
            edge["weight"],
            edge["rank"],
        ]
        for row_id, edge in enumerate(use_relations)
    ]
    relations_context = list_of_list_to_csv(relation_rows)

    community_rows = [["id", "content"]] + [
        [row_id, report["report_string"]]
        for row_id, report in enumerate(use_communities)
    ]
    communities_context = list_of_list_to_csv(community_rows)

    text_unit_rows = [["id", "content"]] + [
        [row_id, unit["content"]] for row_id, unit in enumerate(use_text_units)
    ]
    text_units_context = list_of_list_to_csv(text_unit_rows)

    return f"""
-----Reports-----
```csv
{communities_context}
```
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
"""
async def local_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
) -> str:
    """Answer a query with local (entity-centred) retrieval.

    When ``query_param.only_need_context`` is set, the built context is
    returned as-is, which is ``None`` if nothing was retrieved. Otherwise a
    missing context yields ``PROMPTS["fail_response"]``, and a present one is
    sent to ``global_config["best_model_func"]`` as the system prompt.
    """
    answer_llm = global_config["best_model_func"]
    context = await _build_local_query_context(
        query,
        knowledge_graph_inst,
        entities_vdb,
        community_reports,
        text_chunks_db,
        query_param,
        tokenizer_wrapper,
    )

    # The context-only shortcut is checked before the empty-context fallback.
    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]

    system_prompt = PROMPTS["local_rag_response"].format(
        context_data=context, response_type=query_param.response_type
    )
    return await answer_llm(query, system_prompt=system_prompt)
async def _map_global_communities(
    query: str,
    communities_data: list[CommunitySchema],
    query_param: QueryParam,
    global_config: dict,
    tokenizer_wrapper,
):
    """Run the "map" step of global search over groups of community reports.

    Reports are consumed front to back in groups sized by
    ``truncate_list_by_token_size`` with
    ``query_param.global_max_token_for_community_report``. Each group is
    rendered as CSV and sent to ``global_config["best_model_func"]``; the
    JSON-converted response's ``"points"`` list (default ``[]``) is collected.

    Returns:
        One list of points per group, in group order.
    """
    use_string_json_convert_func = global_config["convert_response_to_json_func"]
    use_model_func = global_config["best_model_func"]

    community_groups = []
    while len(communities_data):
        this_group = truncate_list_by_token_size(
            communities_data,
            key=lambda x: x["report_string"],
            max_token_size=query_param.global_max_token_for_community_report,
            tokenizer_wrapper=tokenizer_wrapper,
        )
        if not this_group:
            # An empty group would leave communities_data unchanged and the
            # loop would never end; give the leading report its own group.
            this_group = communities_data[:1]
        community_groups.append(this_group)
        communities_data = communities_data[len(this_group) :]

    async def _process(community_truncated_datas: list[CommunitySchema]) -> dict:
        communities_section_list = [["id", "content", "rating", "importance"]]
        for i, c in enumerate(community_truncated_datas):
            communities_section_list.append(
                [
                    i,
                    c["report_string"],
                    c["report_json"].get("rating", 0),
                    c["occurrence"],
                ]
            )
        community_context = list_of_list_to_csv(communities_section_list)
        sys_prompt_temp = PROMPTS["global_map_rag_points"]
        sys_prompt = sys_prompt_temp.format(context_data=community_context)
        response = await use_model_func(
            query,
            system_prompt=sys_prompt,
            **query_param.global_special_community_map_llm_kwargs,
        )
        data = use_string_json_convert_func(response)
        return data.get("points", [])

    logger.info(f"Grouping to {len(community_groups)} groups for global search")
    responses = await asyncio.gather(*[_process(c) for c in community_groups])
    return responses
async def global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
) -> str:
    """Answer a query with global (community-report) map-reduce retrieval.

    Communities up to ``query_param.level`` are ranked by occurrence and
    filtered by report rating. The map step extracts scored points from
    them; points with a positive score are ranked, truncated and either
    returned as context or sent to the reduce prompt.
    ``PROMPTS["fail_response"]`` is returned when there are no communities or
    no usable points.
    """
    schema = await knowledge_graph_inst.community_schema()
    schema = {
        key: value for key, value in schema.items() if value["level"] <= query_param.level
    }
    if len(schema) == 0:
        return PROMPTS["fail_response"]
    use_model_func = global_config["best_model_func"]

    # Consider only the most frequently occurring communities.
    by_occurrence = sorted(
        schema.items(), key=lambda item: item[1]["occurrence"], reverse=True
    )
    considered = by_occurrence[: query_param.global_max_consider_community]
    fetched_reports = await community_reports.get_by_ids([key for key, _ in considered])

    min_rating = query_param.global_min_community_rating
    community_datas = [
        report
        for report in fetched_reports
        if report is not None and report["report_json"].get("rating", 0) >= min_rating
    ]
    community_datas.sort(
        key=lambda report: (report["occurrence"], report["report_json"].get("rating", 0)),
        reverse=True,
    )
    logger.info(f"Revtrieved {len(community_datas)} communities")

    map_communities_points = await _map_global_communities(
        query, community_datas, query_param, global_config, tokenizer_wrapper
    )

    # Flatten the per-group points, tagging each with its group index.
    candidate_points = [
        {
            "analyst": analyst_id,
            "answer": point["description"],
            "score": point.get("score", 1),
        }
        for analyst_id, group_points in enumerate(map_communities_points)
        for point in group_points
        if "description" in point
    ]
    final_support_points = [p for p in candidate_points if p["score"] > 0]
    if len(final_support_points) == 0:
        return PROMPTS["fail_response"]

    final_support_points.sort(key=lambda p: p["score"], reverse=True)
    final_support_points = truncate_list_by_token_size(
        final_support_points,
        key=lambda p: p["answer"],
        max_token_size=query_param.global_max_token_for_community_report,
        tokenizer_wrapper=tokenizer_wrapper,
    )

    points_context = "\n".join(
        f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
"""
        for dp in final_support_points
    )
    if query_param.only_need_context:
        return points_context

    reduce_prompt = PROMPTS["global_reduce_rag_response"].format(
        report_data=points_context, response_type=query_param.response_type
    )
    return await use_model_func(query, reduce_prompt)
async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    tokenizer_wrapper,
    global_config: dict,
):
    """Answer a query with plain chunk retrieval (no graph).

    The top ``query_param.top_k`` chunks are retrieved from ``chunks_vdb``
    and loaded from ``text_chunks_db``. They are passed through
    ``truncate_list_by_token_size`` with
    ``query_param.naive_max_token_for_text_unit`` and joined into one context
    section.

    Returns:
        ``PROMPTS["fail_response"]`` when nothing usable was retrieved, the
        context section when ``query_param.only_need_context`` is set,
        otherwise the response of ``global_config["best_model_func"]``.
    """
    use_model_func = global_config["best_model_func"]
    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return PROMPTS["fail_response"]

    chunks_ids = [r["id"] for r in results]
    fetched_chunks = await text_chunks_db.get_by_ids(chunks_ids)

    # Ids without a stored chunk come back as None; drop them so reading
    # "content" below cannot fail.
    chunks = [c for c in fetched_chunks if c is not None]
    if len(chunks) != len(fetched_chunks):
        logger.warning("Text chunks are missing, maybe the storage is damaged")
    if not chunks:
        return PROMPTS["fail_response"]

    maybe_trun_chunks = truncate_list_by_token_size(
        chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.naive_max_token_for_text_unit,
        tokenizer_wrapper=tokenizer_wrapper,
    )
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
    if query_param.only_need_context:
        return section

    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response