Files
rag_agent/RAG-TEST-TOOLS/graph_builder.py

826 lines
33 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
graph_builder.py - 知识图谱构建的核心流程控制器
包含完整的构建流程和交互逻辑
"""
import os
import json
try:
import faiss
except ImportError: # pragma: no cover - depends on deployment environment
faiss = None
try:
import numpy as np
except ImportError: # pragma: no cover - depends on deployment environment
np = None
import time
from datetime import datetime
from pathlib import Path
import logging
from tqdm import tqdm
from collections import defaultdict
from models import GraphNode, GraphEdge, NodeType, EdgeType
from code_parser import (
parser, CPP_EXTENSIONS, IGNORE_DIRS,
safe_decode, extract_comment_before, extract_code_snippet,
extract_full_function_name, collect_call_expressions,
extract_class_name, extract_member_function_name, extract_base_classes
)
from llm_processor import generate_code_logic, generate_code_summary, get_qwen_embedding
from config import PROJECT_ROOT, MAX_CODE_LENGTH, QWEN_EMBEDDING_MODEL
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Global graph data structures shared by the build stages in this module.
# Caller function name -> list of callee names collected from its body.
FUNCTION_CALL_GRAPH = defaultdict(list)
# Source file path -> list of include paths found in that file.
FILE_DEPENDENCIES = defaultdict(list)
# Callee name -> list of caller names; cleared and rebuilt by build_call_edges().
CALLED_BY_GRAPH = defaultdict(list)
def _as_node_type(value):
    """Unwrap enum-like objects: return ``value.value`` when present, else ``value``."""
    return getattr(value, "value", value)
def _normalize_output_dir(output_dir):
    """Return the output directory as a ``Path``, creating it when one is given.

    ``None`` maps to ``Path(".")`` and nothing is created in that case.
    """
    if output_dir is not None:
        resolved = Path(output_dir)
        # Create the whole chain of missing parents; an existing directory is fine.
        resolved.mkdir(parents=True, exist_ok=True)
        return resolved
    return Path(".")
def _qualified_name(node):
    """Return the qualified name of a node.

    For IDs of the form ``Function:<name>`` this is the part after the prefix;
    for every other node it is ``node.name``.
    """
    prefix = "Function:"
    node_id = node.id
    return node_id[len(prefix):] if node_id.startswith(prefix) else node.name
def _has_real_embedding(embedding):
    """Return True when ``embedding`` contains at least one non-zero component.

    ``None``, empty sequences, all-zero vectors, non-iterables and vectors with
    non-numeric entries all count as "no real embedding".

    The emptiness test is done by iterating instead of ``if not embedding`` so
    that array-like inputs (e.g. numpy arrays, whose truth value is ambiguous
    and raises ``ValueError``) are handled instead of crashing.
    """
    if embedding is None:
        return False
    try:
        # 1e-12 is the tolerance below which a component is treated as zero.
        return any(abs(float(value)) > 1e-12 for value in embedding)
    except (TypeError, ValueError):
        # Not iterable, or a component cannot be converted to float.
        return False
def _function_metadata(node, embedding_dim=None):
    """Build the metadata record written alongside a function node's vector.

    :param node: function node; its ``raw_attributes`` may be ``None``.
    :param embedding_dim: dimension to report; when falsy, the length of
        ``node.embedding`` is used (0 if the node has no embedding).
    :return: a plain dict describing the function.
    """
    attrs = node.raw_attributes or {}

    # Prefer the explicit flag; fall back to inspecting the vector itself.
    has_embedding = attrs.get("embedding_available")
    if has_embedding is None:
        has_embedding = _has_real_embedding(node.embedding)

    if embedding_dim:
        dim = embedding_dim
    elif node.embedding:
        dim = len(node.embedding)
    else:
        dim = 0

    display_name = node.name
    source_file = node.file_path
    qualified = attrs.get("qualified_name") or _qualified_name(node)

    record = {
        "node_id": node.id,
        "name": display_name,
        "function_name": display_name,
        "qualified_name": qualified,
        "file": source_file,
        "file_path": source_file,
        "start_line": node.start_line,
        "end_line": node.end_line,
        "signature": node.signature,
        "summary": node.summary or "",
        "logic_flow": node.logic_flow or "",
        "code_snippet": attrs.get("code_snippet", ""),
        "calls": attrs.get("calls", []),
        "called_by": attrs.get("called_by", []),
        "includes": attrs.get("includes", []),
        "embedding_model": QWEN_EMBEDDING_MODEL,
        "embedding_dim": dim,
        "embedding_available": bool(has_embedding),
        "embedding_error": attrs.get("embedding_error", ""),
    }
    return record
class SymbolResolver:
    """Symbol resolver that maps names found in code to graph node IDs."""

    def __init__(self):
        # Symbol name (bare or qualified) -> node ID. A later registration
        # under the same name overwrites the earlier one.
        self._symbol_table = {}
        # File path -> IDs of the entities registered for that file.
        self._file_to_entities = defaultdict(list)

    def register_entity(self, node: "GraphNode"):
        """Register a newly discovered entity in the symbol table.

        The node is registered under ``node.name``. Function nodes whose ID
        carries a longer qualified name (``Function:Class::method`` with
        ``name == "method"``) are additionally registered under that qualified
        name, so references written as ``Class::method`` resolve too.
        """
        self._symbol_table[node.name] = node.id

        prefix = "Function:"
        if node.id.startswith(prefix):
            qualified = node.id[len(prefix):]
            if qualified and qualified != node.name:
                self._symbol_table[qualified] = node.id

        if node.file_path:
            self._file_to_entities[node.file_path].append(node.id)

    def resolve_name(self, name: str, current_context: dict) -> "str | None":
        """Resolve a name used in code to a node ID.

        :param name: the name as written in the code
        :param current_context: current context (file, class, ...); accepted
            for interface compatibility, not consulted by the lookup
        :return: the node ID, or ``None`` when the name is unknown
        """
        # Exact match first.
        if name in self._symbol_table:
            return self._symbol_table[name]

        # Otherwise accept the first registered name that ends in "::<name>".
        suffix = '::' + name
        for known_name, node_id in self._symbol_table.items():
            if known_name.endswith(suffix):
                return node_id
        return None
def _declaration_signature(ast_node):
    """Return the node text up to its first ``{`` with `` {...}`` appended."""
    header = ast_node.text.decode('utf-8', errors='ignore').split('{')[0].strip()
    return header + ' {...}'


def _find_base_clause(class_node):
    """Return the inheritance clause of a class/struct node, or ``None``.

    The ``base_clause`` field lookup is kept for compatibility; when it yields
    nothing, the direct children are searched for a ``base_class_clause`` node,
    which is how the tree-sitter C++ grammar exposes the clause.
    """
    clause = class_node.child_by_field_name("base_clause")
    if clause is not None:
        return clause
    for child in class_node.named_children:
        if child.type == "base_class_clause":
            return child
    return None


def _class_member_nodes(class_node):
    """Return candidate member nodes: direct children plus children of ``body``.

    In the tree-sitter C/C++ grammars inline method definitions sit inside the
    ``body`` field (a ``field_declaration_list``), not directly under the
    class/struct node, so both levels are scanned.
    """
    members = list(class_node.named_children)
    body = class_node.child_by_field_name("body")
    if body is not None:
        members.extend(body.named_children)
    return members


def process_single_file(file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Process one C/C++ file and extract its code entities and structural edges.

    Appends a FILE node, FUNCTION/CLASS nodes and CONTAINS/EXTENDS/IMPLEMENTS
    edges to ``all_nodes`` / ``all_edges``, registers every new node with
    ``symbol_resolver``, records the file's includes in ``FILE_DEPENDENCIES``,
    the callee names of every function in ``FUNCTION_CALL_GRAPH`` and the raw
    source information of every function in ``function_raw_info_map``.

    :param file_path: path of the source file to parse
    :param all_nodes: list of graph nodes, extended in place
    :param all_edges: list of graph edges, extended in place
    :param symbol_resolver: resolver that receives every new node
    :param function_raw_info_map: node ID -> raw info dict, updated in place
    :raises Exception: any read/parse failure is logged and re-raised
    """
    logger.info(f"处理单个文件: {file_path}")
    try:
        with open(file_path, 'rb') as f:
            source_bytes = f.read()
        source_lines = source_bytes.split(b'\n')
        tree = parser.parse(source_bytes)

        # 1. Create the file node.
        file_node_id = f"File:{file_path}"
        file_node = GraphNode(
            id=file_node_id,
            type=NodeType.FILE,
            name=Path(file_path).name,
            file_path=file_path
        )
        all_nodes.append(file_node)
        symbol_resolver.register_entity(file_node)

        # 2. Collect header dependencies.
        includes = []
        for node in tree.root_node.named_children:
            if node.type == "preproc_include":
                path_node = node.child_by_field_name("path")
                if path_node:
                    # errors='ignore' matches the other decodes in this function
                    # so a non-UTF-8 include path cannot abort the whole file.
                    include_text = path_node.text.decode('utf-8', errors='ignore').strip('"<>')
                    includes.append(include_text)
        FILE_DEPENDENCIES[file_path] = includes

        # 3. Walk the top-level declarations of the file.
        for node in tree.root_node.named_children:
            # 3.1 Function definitions.
            if node.type == "function_definition":
                func_name = extract_full_function_name(node)
                if func_name == "<unknown>":
                    # Decode the preview so the log shows text, not a bytes repr.
                    preview = node.text[:50].decode('utf-8', errors='ignore')
                    logger.warning(f" 跳过无法解析函数名的节点: {preview}...")
                    continue

                # Create the function node.
                func_node_id = f"Function:{func_name}"
                func_node = GraphNode(
                    id=func_node_id,
                    type=NodeType.FUNCTION,
                    name=func_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={'ast_node_type': 'function_definition'}
                )
                all_nodes.append(func_node)
                symbol_resolver.register_entity(func_node)

                # Keep the raw information for the semantic stage.
                comment = extract_comment_before(node, source_lines)
                code_snippet = extract_code_snippet(tree, node, source_lines)
                function_raw_info_map[func_node_id] = {
                    'ast_node': node,
                    'comment': comment,
                    'code_snippet': code_snippet,
                    'source_lines': source_lines,
                    'tree': tree
                }

                # FILE -> CONTAINS -> FUNCTION edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=func_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Record the callees of this function.
                callees = collect_call_expressions(node)
                FUNCTION_CALL_GRAPH[func_name] = callees

            # 3.2 Class / struct definitions.
            elif node.type in ["class_specifier", "struct_specifier"]:
                class_name = extract_class_name(node)
                if not class_name or class_name == "<unknown>":
                    logger.debug(" 跳过无法解析类名的节点")
                    continue

                # Create the class node.
                class_node_id = f"Class:{class_name}"
                class_node = GraphNode(
                    id=class_node_id,
                    type=NodeType.CLASS,
                    name=class_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={
                        'ast_type': node.type,
                        'is_struct': node.type == "struct_specifier"
                    }
                )
                all_nodes.append(class_node)
                symbol_resolver.register_entity(class_node)

                # FILE -> CONTAINS -> CLASS edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=class_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Base classes / inheritance.
                # NOTE(review): assumes extract_base_classes accepts the
                # base_class_clause node found by the fallback - confirm.
                base_clause = _find_base_clause(node)
                if base_clause:
                    base_classes = extract_base_classes(base_clause)
                    for base_class_info in base_classes:
                        base_name = base_class_info.get('name')
                        access_type = base_class_info.get('access', 'public')
                        edge_type = (
                            EdgeType.EXTENDS
                            if access_type in ['public', 'protected', 'private']
                            else EdgeType.IMPLEMENTS
                        )
                        if not base_name:
                            continue
                        # Only bases that are already registered can be linked.
                        base_id = symbol_resolver.resolve_name(
                            base_name,
                            {'current_file': file_path, 'current_class': class_name}
                        )
                        if base_id:
                            all_edges.append(GraphEdge(
                                source_id=class_node_id,
                                target_id=base_id,
                                type=edge_type,
                                properties={'access': access_type}
                            ))

                # Methods defined inside the class.
                for member in _class_member_nodes(node):
                    if member.type != "function_definition":
                        continue
                    member_func_name = extract_member_function_name(member, class_name)
                    if not member_func_name or member_func_name == "<unknown>":
                        continue

                    member_node_id = f"Function:{class_name}::{member_func_name}"
                    already_known = any(n.id == member_node_id for n in all_nodes)
                    if not already_known:
                        member_node = GraphNode(
                            id=member_node_id,
                            type=NodeType.FUNCTION,
                            name=member_func_name,
                            signature=_declaration_signature(member),
                            file_path=file_path,
                            start_line=member.start_point[0] + 1,
                            end_line=member.end_point[0] + 1,
                            raw_attributes={
                                'is_member': True,
                                'parent_class': class_name,
                                'ast_node_type': 'member_function_definition'
                            }
                        )
                        all_nodes.append(member_node)
                        symbol_resolver.register_entity(member_node)

                        # Keep the raw information for the semantic stage.
                        member_comment = extract_comment_before(member, source_lines)
                        member_code = extract_code_snippet(tree, member, source_lines)
                        function_raw_info_map[member_node_id] = {
                            'ast_node': member,
                            'comment': member_comment,
                            'code_snippet': member_code,
                            'source_lines': source_lines,
                            'tree': tree
                        }

                        # Record the callees of this member function under its
                        # qualified name.
                        member_callees = collect_call_expressions(member)
                        FUNCTION_CALL_GRAPH[f"{class_name}::{member_func_name}"] = member_callees

                    # CLASS -> CONTAINS -> method edge (also for known nodes).
                    all_edges.append(GraphEdge(
                        source_id=class_node_id,
                        target_id=member_node_id,
                        type=EdgeType.CONTAINS,
                        properties={'member_type': 'method'}
                    ))
    except Exception as e:
        logger.error(f"解析文件 {file_path} 失败: {e}", exc_info=True)
        raise
def process_project_files(project_root, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Process every C/C++ source file under a project directory.

    Directories listed in ``IGNORE_DIRS`` are pruned from the walk and only
    files whose suffix is in ``CPP_EXTENSIONS`` are handed to
    ``process_single_file``. Directories and files are visited in sorted order
    so that repeated runs register symbols in the same order.

    :return: the number of files actually processed. (The previous
        implementation returned the size of the last directory listing and
        failed with ``UnboundLocalError`` when the walk produced nothing.)
    :raises Exception: whatever ``process_single_file`` raises is propagated
    """
    logger.info(f"处理项目目录: {project_root}")
    processed_count = 0
    for root, dirs, files in os.walk(project_root):
        # Prune ignored directories in place so os.walk does not descend into them.
        dirs[:] = sorted(d for d in dirs if d not in IGNORE_DIRS)
        for file in sorted(files):
            if Path(file).suffix.lower() not in CPP_EXTENSIONS:
                continue
            full_file_path = os.path.join(root, file)
            process_single_file(full_file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
            processed_count += 1
    return processed_count
def semantic_enhancement(all_nodes, all_edges, function_raw_info_map, embedding_dim=1024):
    """Stage three: generate summaries and vector embeddings for function nodes.

    For every function node with an entry in ``function_raw_info_map`` this
    generates a logic description, a summary and an embedding, and stores them
    on the node together with its call relations and includes. Nodes without
    raw information receive placeholder text and a zero vector.

    Call relations are looked up by the node's qualified name first (member
    functions are keyed as ``Class::method`` in ``FUNCTION_CALL_GRAPH``) and by
    ``node.name`` as a fallback.

    :param all_nodes: all graph nodes; function nodes are updated in place
    :param all_edges: unused; kept for interface compatibility
    :param function_raw_info_map: node ID -> dict with ``comment`` and
        ``code_snippet``
    :param embedding_dim: length of the zero placeholder vector used when no
        real embedding was obtained during this run; when at least one real
        embedding exists its length is used instead, so all vectors share one
        dimension
    """
    logger.info("阶段三:语义增强,为函数节点生成摘要与向量嵌入...")

    # 1. Split the function nodes by availability of raw information.
    function_nodes = [n for n in all_nodes if n.type == NodeType.FUNCTION]
    nodes_with_raw_info = [n for n in function_nodes if n.id in function_raw_info_map]
    nodes_without_raw_info = [n for n in function_nodes if n.id not in function_raw_info_map]

    logger.info(f"有原始信息的节点: {len(nodes_with_raw_info)}")
    if nodes_without_raw_info:
        logger.warning(f"无原始信息的节点: {len(nodes_without_raw_info)}")

    # Nodes that end up with a zero vector; sized once the real dimension is known.
    placeholder_nodes = []
    real_dim = None

    # 2. Nodes with raw information.
    if nodes_with_raw_info:
        print(f"\n开始为 {len(nodes_with_raw_info)} 个函数生成摘要和向量...")
        print("这可能需要一些时间,请耐心等待...")
        processed_count = 0
        for func_node in tqdm(nodes_with_raw_info, desc="生成函数摘要", unit=""):
            func_name = func_node.name
            qualified_name = _qualified_name(func_node)

            raw_info = function_raw_info_map[func_node.id]
            comment = raw_info['comment']
            code_snippet = raw_info['code_snippet']
            file_path = func_node.file_path

            called_functions = (
                FUNCTION_CALL_GRAPH.get(qualified_name)
                or FUNCTION_CALL_GRAPH.get(func_name, [])
            )
            caller_functions = (
                CALLED_BY_GRAPH.get(qualified_name)
                or CALLED_BY_GRAPH.get(func_name, [])
            )
            func_node.raw_attributes.update({
                "comment": comment,
                "code_snippet": code_snippet,
                "includes": FILE_DEPENDENCIES.get(file_path, []),
                "qualified_name": qualified_name,
                "calls": called_functions,
                "called_by": caller_functions,
            })

            # 2.1 Logic description; a failure is logged and replaced by a marker.
            try:
                logic = generate_code_logic(func_name, comment, code_snippet, file_path)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成逻辑失败: {e}")
                logic = "逻辑生成失败。"

            # 2.2 Functional summary, same best-effort policy.
            try:
                summary = generate_code_summary(
                    func_name, comment, logic, code_snippet, file_path
                )
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成摘要失败: {e}")
                summary = "摘要生成失败。"

            # 2.3 Store the generated text on the node.
            func_node.summary = summary
            func_node.logic_flow = logic

            # 2.4 Vector embedding of the combined description.
            context_text_for_embedding = (
                f"函数名: {func_name}\n"
                f"功能摘要: {summary}\n"
                f"流程逻辑: {logic}\n"
                f"所在文件: {os.path.basename(file_path)}"
            )
            try:
                embedding_vector = get_qwen_embedding(context_text_for_embedding)
                embedding_values = (
                    embedding_vector.tolist()
                    if hasattr(embedding_vector, "tolist")
                    else list(embedding_vector)
                )
                if not _has_real_embedding(embedding_values):
                    raise RuntimeError("Embedding API returned an empty or zero vector.")
                func_node.embedding = embedding_values
                func_node.raw_attributes["embedding_available"] = True
                func_node.raw_attributes["embedding_error"] = ""
                if real_dim is None:
                    real_dim = len(embedding_values)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成嵌入失败: {e}")
                placeholder_nodes.append(func_node)
                func_node.raw_attributes["embedding_available"] = False
                # Cap the stored message so metadata stays small.
                func_node.raw_attributes["embedding_error"] = str(e)[:500]

            processed_count += 1
            # Report progress every five functions.
            if processed_count % 5 == 0:
                print(f" 已处理 {processed_count}/{len(nodes_with_raw_info)} 个函数")
            # Short pause between functions to avoid API rate limiting.
            time.sleep(0.1)
        print(f"✓ 完成 {len(nodes_with_raw_info)} 个函数的语义增强")
    else:
        logger.warning("没有找到任何有原始信息的函数节点,跳过摘要生成。")

    # 3. Nodes without raw information get basic placeholder data.
    if nodes_without_raw_info:
        print(f"\n{len(nodes_without_raw_info)} 个无原始信息的节点生成基础信息...")
        for func_node in nodes_without_raw_info:
            qualified_name = _qualified_name(func_node)
            func_node.summary = f"函数 {func_node.name},位于 {func_node.file_path}"
            func_node.logic_flow = "无法生成逻辑流程(缺少原始代码信息)"
            func_node.raw_attributes.update({
                "calls": (
                    FUNCTION_CALL_GRAPH.get(qualified_name)
                    or FUNCTION_CALL_GRAPH.get(func_node.name, [])
                ),
                "called_by": (
                    CALLED_BY_GRAPH.get(qualified_name)
                    or CALLED_BY_GRAPH.get(func_node.name, [])
                ),
                "includes": FILE_DEPENDENCIES.get(func_node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Missing source information.",
            })
            placeholder_nodes.append(func_node)

    # 4. Size the zero vectors to match the real embeddings of this run.
    placeholder_dim = real_dim or embedding_dim
    for func_node in placeholder_nodes:
        func_node.embedding = [0.0] * placeholder_dim
def build_call_edges(all_nodes, all_edges, symbol_resolver):
    """Stage two: resolve callee names and build the CALLS edges.

    For every caller recorded in ``FUNCTION_CALL_GRAPH`` whose function node
    exists, each callee name is resolved through ``symbol_resolver``. Resolved
    calls become CALLS edges in ``all_edges`` and are recorded in the reverse
    graph ``CALLED_BY_GRAPH`` (callee name -> caller names), which is cleared
    first.

    :param all_nodes: all graph nodes (read only)
    :param all_edges: list of graph edges, extended in place
    :param symbol_resolver: resolver used to map callee names to node IDs
    """
    logger.info("阶段二:解析符号,构建函数调用图 (CALLS 边)...")

    # Rebuild the reverse call graph from scratch.
    CALLED_BY_GRAPH.clear()

    # Node ID -> file path of the first node with that ID. Built once so the
    # loops below do not rescan ``all_nodes`` for every caller and callee.
    file_by_node_id = {}
    for node in all_nodes:
        file_by_node_id.setdefault(node.id, node.file_path)

    for caller_name, callee_names in FUNCTION_CALL_GRAPH.items():
        caller_node_id = f"Function:{caller_name}"
        # Skip callers that have no node in the graph.
        if caller_node_id not in file_by_node_id:
            continue

        context = {'current_file': file_by_node_id[caller_node_id]}
        for callee_name in callee_names:
            # Map the callee name to a node ID via the symbol resolver.
            callee_node_id = symbol_resolver.resolve_name(callee_name, context)
            if not callee_node_id:
                continue
            # CALLS edge.
            all_edges.append(GraphEdge(
                source_id=caller_node_id,
                target_id=callee_node_id,
                type=EdgeType.CALLS
            ))
            # Reverse graph; the defaultdict creates the list on first use.
            CALLED_BY_GRAPH[callee_name].append(caller_name)

    logger.info(f"阶段二完成。已构建边数: {len(all_edges)}")
def save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=None):
    """Stage four: persist the graph, the vector index and its metadata.

    Writes ``<base_name>_code_knowledge_graph.json`` always, and
    ``<base_name>_rag.faiss`` plus ``<base_name>_rag_metadata.json`` when at
    least one function node has an embedding. Without the ``faiss`` package
    the vectors are written as JSON to the same ``.faiss`` path.

    All indexed vectors must share one dimension. It is taken from the first
    node with a real (non-zero) embedding, falling back to the first node.
    Zero placeholder vectors of another length are resized; real vectors of
    another length are skipped with a warning instead of crashing the build.

    :param all_nodes: graph nodes to store
    :param all_edges: graph edges to store
    :param base_name: prefix of the output file names
    :param target_path: project root or file path recorded in the metadata
    :param mode: ``"full_project"`` or ``"single_file"``
    :param output_dir: target directory; ``None`` means the current directory
    :return: ``(graph_data_path, vector_db_path, metadata_path)`` as strings;
        the last two are returned even when no index was written
    :raises RuntimeError: if faiss is available but numpy is not
    """
    logger.info("阶段四:持久化图谱与向量索引...")
    output_path = _normalize_output_dir(output_dir)
    graph_data_path = str(output_path / f"{base_name}_code_knowledge_graph.json")
    vector_db_path = str(output_path / f"{base_name}_rag.faiss")
    metadata_path = str(output_path / f"{base_name}_rag_metadata.json")

    # 1. Save the graph file.
    graph_data = {
        "metadata": {
            "project_root": target_path if mode == "full_project" else "single_file",
            "file_path": target_path if mode == "single_file" else None,
            "mode": mode,
            "generated_at": datetime.now().isoformat(),
            "total_nodes": len(all_nodes),
            "total_edges": len(all_edges)
        },
        "nodes": [node.dict() for node in all_nodes],
        "edges": [edge.dict() for edge in all_edges]
    }
    with open(graph_data_path, 'w', encoding='utf-8') as f:
        # default=str stringifies values json cannot encode (e.g. enums, paths).
        json.dump(graph_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f" 图谱结构已保存至: {graph_data_path}")
    print(f"✓ 图谱结构已保存: {graph_data_path}")

    # 2. Build and save the vector index.
    function_nodes = [node for node in all_nodes if node.type == NodeType.FUNCTION and node.embedding]
    if not function_nodes:
        logger.warning(" 没有可生成向量的函数节点跳过FAISS索引构建。")
        print("⚠ 无向量数据跳过FAISS索引构建")
        return graph_data_path, vector_db_path, metadata_path

    dimension = next(
        (len(node.embedding) for node in function_nodes if _has_real_embedding(node.embedding)),
        len(function_nodes[0].embedding),
    )

    # Vectors and metadata are appended together so row i of the index always
    # corresponds to entry i of the metadata file.
    metadata_for_faiss = []
    vectors = []
    for node in function_nodes:
        vector = list(node.embedding)
        if len(vector) != dimension:
            if _has_real_embedding(vector):
                logger.warning(
                    f"Skipping {node.id}: embedding dimension {len(vector)} "
                    f"does not match index dimension {dimension}."
                )
                continue
            # Zero placeholder of the wrong size: resize it.
            vector = [0.0] * dimension
        vectors.append(vector)
        metadata_for_faiss.append(_function_metadata(node, dimension))

    if faiss is not None:
        if np is None:
            raise RuntimeError("numpy is required when writing a FAISS index.")
        vectors_np = np.array(vectors, dtype='float32')
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors_np)
        faiss.write_index(index, vector_db_path)
    else:
        # Fallback when faiss is not installed: plain JSON in the same path.
        with open(vector_db_path, 'w', encoding='utf-8') as f:
            json.dump(
                {
                    "format": "simple_l2_vector_index",
                    "dimension": dimension,
                    "vectors": [[float(value) for value in vector] for vector in vectors],
                },
                f,
                ensure_ascii=False,
            )

    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata_for_faiss, f, indent=2, ensure_ascii=False)

    logger.info(f" 向量索引已保存。函数节点数: {len(vectors)}")
    logger.info(f" 向量数据库: {vector_db_path}")
    logger.info(f" 元数据: {metadata_path}")
    print(f"✓ 向量索引已保存: {vector_db_path}")
    return graph_data_path, vector_db_path, metadata_path
def build_code_knowledge_base(target_path, output_dir=None, semantic=True, base_name=None, embedding_dim=1024):
    """Build a code knowledge base without interactive prompts.

    Runs static analysis, call-edge construction, optional semantic
    enhancement and persistence for a directory (``full_project`` mode) or a
    single file (``single_file`` mode).

    :param target_path: project directory or source file
    :param output_dir: directory for the output files; ``None`` means the
        current directory
    :param semantic: when false, the LLM/embedding stage is skipped and
        function nodes get placeholder text and zero vectors
    :param base_name: output file prefix; defaults to ``"project"`` or
        ``"file_<stem>"``
    :param embedding_dim: length of the zero vectors used when ``semantic`` is
        false
    :return: ``(graph_data_path, vector_db_path, metadata_path)``
    :raises FileNotFoundError: if ``target_path`` does not exist
    """
    target_path = str(Path(target_path).resolve())
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path does not exist: {target_path}")

    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}

    # The graphs are module-level state; reset them so runs do not leak into
    # each other.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    if os.path.isdir(target_path):
        mode = "full_project"
        process_project_files(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or "project"
    else:
        mode = "single_file"
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or f"file_{Path(target_path).stem}"

    build_call_edges(all_nodes, all_edges, symbol_resolver)

    if semantic:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)
    else:
        for node in all_nodes:
            if node.type != NodeType.FUNCTION:
                continue
            raw_info = function_raw_info_map.get(node.id, {})
            # Member functions are keyed as "Class::method" in the call graph,
            # so look up the qualified name first and the bare name second.
            qualified_name = _qualified_name(node)
            node.summary = node.summary or f"Function {node.name}, located at {node.file_path}"
            node.logic_flow = node.logic_flow or "Semantic enhancement was skipped."
            node.embedding = node.embedding or [0.0] * embedding_dim
            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                "calls": (
                    FUNCTION_CALL_GRAPH.get(qualified_name)
                    or FUNCTION_CALL_GRAPH.get(node.name, [])
                ),
                "called_by": (
                    CALLED_BY_GRAPH.get(qualified_name)
                    or CALLED_BY_GRAPH.get(node.name, [])
                ),
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })

    return save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=output_dir)
def _confirm_reentry():
    """Ask whether to enter a path again; print the cancel notice on refusal.

    :return: True when the user answered ``y``, False otherwise
    """
    if input("是否重新输入? (y/n): ").strip().lower() == 'y':
        return True
    print("知识库构建取消。")
    return False


def _prompt_project_root():
    """Prompt until a valid project directory is chosen.

    Empty input selects ``PROJECT_ROOT`` from the configuration when it is an
    existing directory.

    :return: the chosen directory, or ``None`` when the user cancelled
    """
    # Show the default path currently configured in config.py.
    current_default_path = str(PROJECT_ROOT) if PROJECT_ROOT else "未配置"
    if not current_default_path or current_default_path == ".":
        current_default_path = "当前工作目录"
    project_prompt = f"请直接输入要构建知识库的项目根目录完整路径\n(当前config.py中的默认路径: {current_default_path}): "

    while True:
        user_input = input(project_prompt).strip()
        if user_input:
            target_path = Path(user_input)
        elif PROJECT_ROOT and os.path.exists(PROJECT_ROOT) and os.path.isdir(PROJECT_ROOT):
            target_path = PROJECT_ROOT
            print(f"使用config.py中的默认路径: {target_path}")
        else:
            print("⚠ 您没有输入路径且config.py中的默认路径无效。")
            if not _confirm_reentry():
                return None
            continue

        # Validate the chosen path.
        if not os.path.exists(target_path):
            print(f"⚠ 路径不存在: {target_path}")
            if not _confirm_reentry():
                return None
            continue
        if not os.path.isdir(target_path):
            print(f"⚠ 路径不是目录: {target_path}")
            if not _confirm_reentry():
                return None
            continue
        return target_path


def _prompt_source_file():
    """Prompt until an existing source file is chosen.

    :return: the file path, or ``None`` when the user entered nothing
    """
    while True:
        file_path = input("\n请输入要处理的文件完整路径: ").strip()
        if not file_path:
            print("输入为空,退出单文件模式")
            return None
        if not os.path.exists(file_path):
            print(f"错误: 文件不存在: {file_path}")
            continue
        # A directory would pass the existence check but cannot be opened as a file.
        if not os.path.isfile(file_path):
            print(f"错误: 路径不是文件: {file_path}")
            continue
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in CPP_EXTENSIONS:
            print(f"警告: 文件类型 {file_ext} 可能不是C/C++文件,继续吗? (y/n): ", end="")
            if input().strip().lower() != 'y':
                continue
        return file_path


def build_rag_database_interactive():
    """Interactive main flow for building the knowledge base.

    Asks for the mode (whole project / single file), then runs static
    analysis, call-edge construction, optional semantic enhancement and
    optional persistence, confirming with the user before the costly steps.

    :return: ``(all_nodes, all_edges)`` after a build, whether or not the
        result was saved, or ``(None, None)`` when the user quit or cancelled
        before the build started
    """
    logger.info("=== 交互式知识图谱构建流程 ===")

    # Mode selection menu.
    print("\n" + "=" * 50)
    print("知识库构建模式选择")
    print("=" * 50)
    print("1. 全项目模式 - 处理整个项目目录")
    print("2. 单文件模式 - 处理单个文件")
    print("3. 退出")
    print("=" * 50)
    choice = input("请选择模式 (1-3): ").strip()

    if choice == "3":
        logger.info("用户选择退出")
        return None, None

    if choice == "1":
        # Whole-project mode.
        print("\n" + "=" * 60)
        print("全项目知识库构建模式")
        print("=" * 60)
        target_path = _prompt_project_root()
        if target_path is None:
            return None, None
        mode = "full_project"
        print(f"\n✓ 将处理整个项目: {target_path}")
        print("开始处理...")
    elif choice == "2":
        # Single-file mode.
        target_path = _prompt_source_file()
        if target_path is None:
            return None, None
        mode = "single_file"
        print(f"\n将处理单个文件: {target_path}")
        print("开始处理...")
    else:
        print("无效的选择,退出")
        return None, None

    # Final confirmation before any work is done.
    print("\n即将开始构建知识库,确认继续? (y/n): ", end="")
    if input().strip().lower() != 'y':
        print("操作已取消")
        return None, None

    # --- Initialise the data structures ---
    all_nodes = []  # GraphNode objects
    all_edges = []  # GraphEdge objects
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}  # temporary raw information per function

    # Reset the module-level graphs.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    # ========================
    # Stage 1: static analysis, extract nodes and structural edges
    # ========================
    logger.info("阶段一:静态分析,提取代码实体与结构关系...")
    if mode == "full_project":
        processed_files_count = process_project_files(target_path, all_nodes, all_edges,
                                                      symbol_resolver, function_raw_info_map)
        logger.info(f"已处理 {processed_files_count} 个文件")
    else:
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        logger.info(f"已处理单个文件: {target_path}")
    logger.info(f"阶段一完成。初步识别节点数: {len(all_nodes)}")

    # ========================
    # Stage 2: resolve symbols, build call edges
    # ========================
    build_call_edges(all_nodes, all_edges, symbol_resolver)

    # ========================
    # Stage 3: semantic enhancement (summaries and vectors)
    # ========================
    function_count = len([n for n in all_nodes if n.type == NodeType.FUNCTION])
    print("\n" + "=" * 50)
    print("阶段三:语义增强")
    print("=" * 50)
    print("此阶段将调用LLM API为函数生成摘要和向量嵌入")
    print(f"需要处理的函数节点数: {function_count}")
    print("这可能会消耗API调用额度并需要一些时间")
    print("-" * 50)
    proceed = input("是否继续阶段三? (y/n): ").strip().lower()
    if proceed != 'y':
        logger.info("用户跳过阶段三(语义增强)")
        print("跳过阶段三,直接进入阶段四")
        # Give the unprocessed function nodes default values.
        for node in all_nodes:
            if node.type != NodeType.FUNCTION or node.summary:
                continue
            node.summary = f"函数 {node.name},位于 {node.file_path}"
            node.logic_flow = "跳过语义增强阶段"
            node.embedding = [0.0] * 1024
            raw_info = function_raw_info_map.get(node.id, {})
            # Member functions are keyed as "Class::method" in the call graph,
            # so look up the qualified name first and the bare name second.
            qualified_name = _qualified_name(node)
            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                "calls": (
                    FUNCTION_CALL_GRAPH.get(qualified_name)
                    or FUNCTION_CALL_GRAPH.get(node.name, [])
                ),
                "called_by": (
                    CALLED_BY_GRAPH.get(qualified_name)
                    or CALLED_BY_GRAPH.get(node.name, [])
                ),
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })
    else:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)

    # ========================
    # Stage 4: persistence
    # ========================
    print("\n" + "=" * 50)
    print("阶段四:持久化存储")
    print("=" * 50)
    save_choice = input("是否保存结果? (y/n): ").strip().lower()
    if save_choice != 'y':
        logger.info("用户取消保存,退出构建流程")
        print("构建流程已取消")
        return all_nodes, all_edges

    # Choose the output file prefix.
    if mode == "full_project":
        base_name = "project"
    else:
        base_name = f"file_{Path(target_path).stem}"

    graph_data_path, vector_db_path, metadata_path = save_results(
        all_nodes, all_edges, base_name, target_path, mode
    )

    print("\n" + "=" * 50)
    print("构建完成!")
    print("=" * 50)
    print(f"总节点数: {len(all_nodes)}")
    print(f"总边数: {len(all_edges)}")
    print(f"图谱文件: {graph_data_path}")
    # The vector file is only written when there were vectors to index.
    if os.path.exists(vector_db_path):
        print(f"向量文件: {vector_db_path}")
    logger.info("=== 知识图谱构建流程全部完成 ===")
    return all_nodes, all_edges