# Repository-viewer residue (not part of the module): Files / 826 lines / 33 KiB / Python / Raw / Permalink / Normal View / History

"""
graph_builder.py - 知识图谱构建的核心流程控制器
包含完整的构建流程和交互逻辑
"""
import os
import json
try:
import faiss
except ImportError: # pragma: no cover - depends on deployment environment
faiss = None
try:
import numpy as np
except ImportError: # pragma: no cover - depends on deployment environment
np = None
import time
from datetime import datetime
from pathlib import Path
import logging
from tqdm import tqdm
from collections import defaultdict
from models import GraphNode, GraphEdge, NodeType, EdgeType
from code_parser import (
parser, CPP_EXTENSIONS, IGNORE_DIRS,
safe_decode, extract_comment_before, extract_code_snippet,
extract_full_function_name, collect_call_expressions,
extract_class_name, extract_member_function_name, extract_base_classes
)
from llm_processor import generate_code_logic, generate_code_summary, get_qwen_embedding
from config import PROJECT_ROOT, MAX_CODE_LENGTH, QWEN_EMBEDDING_MODEL
logger = logging.getLogger(__name__)
# Global graph data structures shared by every build phase. Both build entry
# points clear them before starting a new run.
# Caller name -> callee names returned by collect_call_expressions(); member
# functions are keyed as "Class::method".
FUNCTION_CALL_GRAPH = defaultdict(list)
# Source file path -> include targets collected from that file's
# preproc_include nodes.
FILE_DEPENDENCIES = defaultdict(list)
# Callee name (as reported by collect_call_expressions()) -> caller names;
# rebuilt from scratch by build_call_edges().
CALLED_BY_GRAPH = defaultdict(list)
def _as_node_type(value):
    """Return ``value.value`` for enum-like objects, otherwise ``value`` itself."""
    return getattr(value, "value", value)
def _normalize_output_dir(output_dir):
    """Return the output directory as a ``Path``, creating it when needed.

    ``None`` maps to the current directory, which is returned without
    touching the filesystem. Any other value is converted to a ``Path`` and
    created together with its missing parents.
    """
    if output_dir is not None:
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir
    return Path(".")
def _qualified_name(node):
    """Return the qualified name of a graph node.

    Ids of the form ``Function:<qualified name>`` yield the part after the
    prefix; every other node falls back to ``node.name``.
    """
    prefix = "Function:"
    node_id = node.id
    if not node_id.startswith(prefix):
        return node.name
    return node_id[len(prefix):]
def _has_real_embedding(embedding):
    """Return True when ``embedding`` has at least one non-zero component.

    Accepts any iterable of numbers, including array-like objects whose
    truth value is ambiguous (a multi-element NumPy array raises on
    ``not array``). ``None``, empty sequences, all-zero vectors,
    non-iterables and vectors containing non-numeric items all count as
    "no real embedding".

    :param embedding: candidate vector, or None
    :return: True if any component has magnitude above 1e-12
    """
    if embedding is None:
        return False
    try:
        # iter() runs when the generator is created, so a non-iterable
        # argument is reported through the TypeError handler below.
        return any(abs(float(component)) > 1e-12 for component in embedding)
    except (TypeError, ValueError):
        return False
def _function_metadata(node, embedding_dim=None):
    """Build the metadata record stored next to a function's vector.

    :param node: function graph node
    :param embedding_dim: dimension to report; when falsy, the length of
        ``node.embedding`` is used (0 if the node has no embedding)
    :return: dict describing the function, its call relations and the state
        of its embedding
    """
    attrs = node.raw_attributes or {}

    # Prefer the flag recorded during semantic enhancement; otherwise derive
    # it from the vector itself.
    available = attrs.get("embedding_available")
    if available is None:
        available = _has_real_embedding(node.embedding)

    if embedding_dim:
        reported_dim = embedding_dim
    elif node.embedding:
        reported_dim = len(node.embedding)
    else:
        reported_dim = 0

    qualified = attrs.get("qualified_name") or _qualified_name(node)

    metadata = {}
    metadata["node_id"] = node.id
    metadata["name"] = node.name
    metadata["function_name"] = node.name
    metadata["qualified_name"] = qualified
    metadata["file"] = node.file_path
    metadata["file_path"] = node.file_path
    metadata["start_line"] = node.start_line
    metadata["end_line"] = node.end_line
    metadata["signature"] = node.signature
    metadata["summary"] = node.summary or ""
    metadata["logic_flow"] = node.logic_flow or ""
    metadata["code_snippet"] = attrs.get("code_snippet", "")
    metadata["calls"] = attrs.get("calls", [])
    metadata["called_by"] = attrs.get("called_by", [])
    metadata["includes"] = attrs.get("includes", [])
    metadata["embedding_model"] = QWEN_EMBEDDING_MODEL
    metadata["embedding_dim"] = reported_dim
    metadata["embedding_available"] = bool(available)
    metadata["embedding_error"] = attrs.get("embedding_error", "")
    return metadata
class SymbolResolver:
    """Symbol resolver that maps names found in code to graph node ids."""

    def __init__(self):
        # name -> node id; a later registration under the same name replaces
        # the earlier one.
        self._symbol_table = {}
        # file path -> ids of the entities registered from that file.
        self._file_to_entities = defaultdict(list)

    def register_entity(self, node: "GraphNode"):
        """Register a newly discovered entity in the symbol table."""
        self._symbol_table[node.name] = node.id
        if node.file_path:
            self._file_to_entities[node.file_path].append(node.id)

    def resolve_name(self, name: str, current_context: dict) -> str:
        """Resolve a name used in code to a node id.

        An exact match on the registered name wins. Otherwise the first
        registered name (in registration order) that ends with ``::<name>``
        is used.

        :param name: name as it appears in code
        :param current_context: current context (file, class, ...); not
            consulted by this implementation
        :return: node id, or None when nothing matches
        """
        table = self._symbol_table
        if name in table:
            return table[name]
        scoped_suffix = "::" + name
        return next(
            (node_id for known, node_id in table.items() if known.endswith(scoped_suffix)),
            None,
        )
def _declaration_signature(ast_node):
    """Return the declaration text before the first ``{``, followed by `` {...}``."""
    header = ast_node.text.decode('utf-8', errors='ignore').split('{')[0].strip()
    return header + ' {...}'


def _store_function_raw_info(function_raw_info_map, node_id, ast_node, tree, source_lines):
    """Cache the leading comment, code snippet and AST handles of one function."""
    function_raw_info_map[node_id] = {
        'ast_node': ast_node,
        'comment': extract_comment_before(ast_node, source_lines),
        'code_snippet': extract_code_snippet(tree, ast_node, source_lines),
        'source_lines': source_lines,
        'tree': tree,
    }


def process_single_file(file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Parse one C/C++ file and extract its code entities and structural edges.

    Adds a FILE node, one FUNCTION node per top-level function definition,
    one CLASS node per top-level class/struct specifier and one FUNCTION node
    per method definition found among a specifier's named children, together
    with the matching CONTAINS / EXTENDS / IMPLEMENTS edges. Include targets
    are recorded in ``FILE_DEPENDENCIES`` and call expressions in
    ``FUNCTION_CALL_GRAPH``.

    :param file_path: path of the source file to parse
    :param all_nodes: list that receives the created GraphNode objects
    :param all_edges: list that receives the created GraphEdge objects
    :param symbol_resolver: SymbolResolver in which every node is registered
    :param function_raw_info_map: dict (node id -> raw info) filled for each
        function so later phases can reach its comment and code snippet
    :raises Exception: any read or parse error is logged and re-raised
    """
    logger.info(f"处理单个文件: {file_path}")
    try:
        with open(file_path, 'rb') as f:
            source_bytes = f.read()
        source_lines = source_bytes.split(b'\n')
        tree = parser.parse(source_bytes)
        top_level_nodes = tree.root_node.named_children

        # 1. Create the file node.
        file_node_id = f"File:{file_path}"
        file_node = GraphNode(
            id=file_node_id,
            type=NodeType.FILE,
            name=Path(file_path).name,
            file_path=file_path
        )
        all_nodes.append(file_node)
        symbol_resolver.register_entity(file_node)

        # 2. Collect header dependencies.
        includes = []
        for node in top_level_nodes:
            if node.type != "preproc_include":
                continue
            path_node = node.child_by_field_name("path")
            if path_node:
                # Decode leniently, like every other decode in this function,
                # so a stray non-UTF-8 byte cannot abort the whole file.
                include_text = path_node.text.decode('utf-8', errors='ignore').strip('"<>')
                includes.append(include_text)
        FILE_DEPENDENCIES[file_path] = includes

        # Ids already present in the graph, kept in sync with the appends
        # below so the member lookup does not rescan all_nodes per method.
        known_ids = {existing.id for existing in all_nodes}

        # 3. Walk the top-level declarations of the file.
        for node in top_level_nodes:
            # 3.1 Function definitions.
            if node.type == "function_definition":
                func_name = extract_full_function_name(node)
                if func_name == "<unknown>":
                    preview = node.text[:50].decode('utf-8', errors='ignore')
                    logger.warning(f" 跳过无法解析函数名的节点: {preview}...")
                    continue

                func_node_id = f"Function:{func_name}"
                func_node = GraphNode(
                    id=func_node_id,
                    type=NodeType.FUNCTION,
                    name=func_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={'ast_node_type': 'function_definition'}
                )
                all_nodes.append(func_node)
                known_ids.add(func_node_id)
                symbol_resolver.register_entity(func_node)

                # Keep the raw material needed by the semantic phase.
                _store_function_raw_info(function_raw_info_map, func_node_id, node, tree, source_lines)

                # FILE -> CONTAINS -> FUNCTION edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=func_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Record the calls made by this function.
                FUNCTION_CALL_GRAPH[func_name] = collect_call_expressions(node)

            # 3.2 Class / struct definitions.
            elif node.type in ("class_specifier", "struct_specifier"):
                class_name = extract_class_name(node)
                if not class_name or class_name == "<unknown>":
                    logger.debug(f" 跳过无法解析类名的节点")
                    continue

                class_node_id = f"Class:{class_name}"
                class_node = GraphNode(
                    id=class_node_id,
                    type=NodeType.CLASS,
                    name=class_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={
                        'ast_type': node.type,
                        'is_struct': node.type == "struct_specifier"
                    }
                )
                all_nodes.append(class_node)
                known_ids.add(class_node_id)
                symbol_resolver.register_entity(class_node)

                # FILE -> CONTAINS -> CLASS edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=class_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Base classes / inheritance.
                # NOTE(review): whether the grammar exposes a "base_clause"
                # field, and whether method definitions are direct named
                # children of the specifier, depends on the tree-sitter
                # grammar in use - confirm.
                base_clause = node.child_by_field_name("base_clause")
                if base_clause:
                    for base_class_info in extract_base_classes(base_clause):
                        base_name = base_class_info.get('name')
                        access_type = base_class_info.get('access', 'public')
                        if access_type in ('public', 'protected', 'private'):
                            edge_type = EdgeType.EXTENDS
                        else:
                            edge_type = EdgeType.IMPLEMENTS
                        if not base_name:
                            continue
                        # Only bases that are already registered resolve;
                        # unresolved bases produce no edge.
                        base_id = symbol_resolver.resolve_name(
                            base_name,
                            {'current_file': file_path, 'current_class': class_name}
                        )
                        if base_id:
                            all_edges.append(GraphEdge(
                                source_id=class_node_id,
                                target_id=base_id,
                                type=edge_type,
                                properties={'access': access_type}
                            ))

                # Methods defined inside the class.
                for member in node.named_children:
                    if member.type != "function_definition":
                        continue
                    member_func_name = extract_member_function_name(member, class_name)
                    if not member_func_name or member_func_name == "<unknown>":
                        continue

                    member_node_id = f"Function:{class_name}::{member_func_name}"
                    # An id that is already known keeps its first
                    # registration; only the CONTAINS edge below is added.
                    if member_node_id not in known_ids:
                        member_node = GraphNode(
                            id=member_node_id,
                            type=NodeType.FUNCTION,
                            name=member_func_name,
                            signature=_declaration_signature(member),
                            file_path=file_path,
                            start_line=member.start_point[0] + 1,
                            end_line=member.end_point[0] + 1,
                            raw_attributes={
                                'is_member': True,
                                'parent_class': class_name,
                                'ast_node_type': 'member_function_definition'
                            }
                        )
                        all_nodes.append(member_node)
                        known_ids.add(member_node_id)
                        symbol_resolver.register_entity(member_node)

                        _store_function_raw_info(
                            function_raw_info_map, member_node_id, member, tree, source_lines
                        )

                        # Record the calls made by this method.
                        FUNCTION_CALL_GRAPH[f"{class_name}::{member_func_name}"] = (
                            collect_call_expressions(member)
                        )

                    # CLASS -> CONTAINS -> method edge.
                    all_edges.append(GraphEdge(
                        source_id=class_node_id,
                        target_id=member_node_id,
                        type=EdgeType.CONTAINS,
                        properties={'member_type': 'method'}
                    ))
    except Exception as e:
        logger.error(f"解析文件 {file_path} 失败: {e}", exc_info=True)
        raise
def process_project_files(project_root, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Process every C/C++ source file below a project directory.

    Directories listed in ``IGNORE_DIRS`` are pruned from the walk and files
    whose suffix is not in ``CPP_EXTENSIONS`` are skipped. Directories and
    files are visited in sorted order so repeated builds register entities
    in the same order.

    :param project_root: directory to walk
    :param all_nodes: list that receives the created GraphNode objects
    :param all_edges: list that receives the created GraphEdge objects
    :param symbol_resolver: SymbolResolver shared by all files
    :param function_raw_info_map: dict filled with per-function raw info
    :return: number of source files that were actually processed
    :raises Exception: whatever ``process_single_file`` raises for a file
    """
    logger.info(f"处理项目目录: {project_root}")
    processed_count = 0
    for root, dirs, files in os.walk(project_root):
        # Prune ignored directories in place so os.walk does not descend
        # into them.
        dirs[:] = sorted(d for d in dirs if d not in IGNORE_DIRS)
        for file_name in sorted(files):
            if Path(file_name).suffix.lower() not in CPP_EXTENSIONS:
                continue
            full_file_path = os.path.join(root, file_name)
            process_single_file(
                full_file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map
            )
            processed_count += 1
    return processed_count
def semantic_enhancement(all_nodes, all_edges, function_raw_info_map, embedding_dim=1024):
    """Phase 3: generate a summary, logic flow and embedding per function node.

    Function nodes with an entry in ``function_raw_info_map`` are sent to the
    LLM helpers; a failure in any single step is logged and replaced by a
    fallback value so one function cannot abort the whole run. Function
    nodes without raw info get generic text and a zero placeholder vector.

    Call relations are looked up under the function's qualified name (the
    key used when the call graph is filled), falling back to the plain name.

    :param all_nodes: all graph nodes; function nodes are updated in place
    :param all_edges: all graph edges (not used here, kept for interface
        symmetry with the other phases)
    :param function_raw_info_map: node id -> raw info gathered while parsing
    :param embedding_dim: length of zero placeholder vectors when no real
        embedding was obtained; once a real embedding has been seen its
        length takes precedence so every stored vector has the same size
    """
    logger.info("阶段三:语义增强,为函数节点生成摘要与向量嵌入...")

    def call_relations(node):
        """Return (calls, called_by) lists for a function node."""
        qualified = _qualified_name(node)
        if qualified in FUNCTION_CALL_GRAPH:
            calls = list(FUNCTION_CALL_GRAPH[qualified])
        else:
            calls = list(FUNCTION_CALL_GRAPH.get(node.name, []))
        callers = list(CALLED_BY_GRAPH.get(node.name, []))
        if qualified != node.name:
            callers.extend(
                caller for caller in CALLED_BY_GRAPH.get(qualified, [])
                if caller not in callers
            )
        return calls, callers

    # 1. Split the function nodes by availability of raw source information.
    function_nodes = [n for n in all_nodes if n.type == NodeType.FUNCTION]
    nodes_with_raw_info = []
    nodes_without_raw_info = []
    for func_node in function_nodes:
        if func_node.id in function_raw_info_map:
            nodes_with_raw_info.append(func_node)
        else:
            nodes_without_raw_info.append(func_node)
    logger.info(f"有原始信息的节点: {len(nodes_with_raw_info)}")
    if nodes_without_raw_info:
        logger.warning(f"无原始信息的节点: {len(nodes_without_raw_info)}")

    observed_dim = None      # length of the first real embedding received
    placeholder_nodes = []   # nodes that currently hold a zero vector

    # 2. Nodes with raw information.
    if nodes_with_raw_info:
        print(f"\n开始为 {len(nodes_with_raw_info)} 个函数生成摘要和向量...")
        print("这可能需要一些时间,请耐心等待...")
        processed_count = 0
        for func_node in tqdm(nodes_with_raw_info, desc="生成函数摘要", unit=""):
            func_name = func_node.name
            raw_info = function_raw_info_map[func_node.id]
            comment = raw_info['comment']
            code_snippet = raw_info['code_snippet']
            file_path = func_node.file_path
            if func_node.raw_attributes is None:
                func_node.raw_attributes = {}

            # 2.1 Logic flow.
            try:
                logic = generate_code_logic(func_name, comment, code_snippet, file_path)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成逻辑失败: {e}")
                logic = "逻辑生成失败。"

            # 2.2 Functional summary.
            try:
                summary = generate_code_summary(
                    func_name, comment, logic, code_snippet, file_path
                )
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成摘要失败: {e}")
                summary = "摘要生成失败。"

            # 2.3 Store the results and the call relations on the node.
            func_node.summary = summary
            func_node.logic_flow = logic
            called_functions, caller_functions = call_relations(func_node)
            func_node.raw_attributes.update({
                "comment": comment,
                "code_snippet": code_snippet,
                "includes": FILE_DEPENDENCIES.get(file_path, []),
                "qualified_name": _qualified_name(func_node),
                "calls": called_functions,
                "called_by": caller_functions,
            })

            # 2.4 Embedding.
            context_text_for_embedding = (
                f"函数名: {func_name}\n"
                f"功能摘要: {summary}\n"
                f"流程逻辑: {logic}\n"
                f"所在文件: {os.path.basename(file_path)}"
            )
            try:
                embedding_vector = get_qwen_embedding(context_text_for_embedding)
                embedding_values = (
                    embedding_vector.tolist()
                    if hasattr(embedding_vector, "tolist")
                    else list(embedding_vector)
                )
                if not _has_real_embedding(embedding_values):
                    raise RuntimeError("Embedding API returned an empty or zero vector.")
                func_node.embedding = embedding_values
                func_node.raw_attributes["embedding_available"] = True
                func_node.raw_attributes["embedding_error"] = ""
                if observed_dim is None:
                    observed_dim = len(embedding_values)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成嵌入失败: {e}")
                func_node.embedding = [0.0] * embedding_dim
                func_node.raw_attributes["embedding_available"] = False
                func_node.raw_attributes["embedding_error"] = str(e)[:500]
                placeholder_nodes.append(func_node)

            processed_count += 1
            # Report progress every five functions.
            if processed_count % 5 == 0:
                print(f" 已处理 {processed_count}/{len(nodes_with_raw_info)} 个函数")
            # Short pause between functions to stay clear of API rate limits.
            time.sleep(0.1)
        print(f"✓ 完成 {len(nodes_with_raw_info)} 个函数的语义增强")
    else:
        logger.warning("没有找到任何有原始信息的函数节点,跳过摘要生成。")

    # 3. Nodes without raw information.
    if nodes_without_raw_info:
        print(f"\n{len(nodes_without_raw_info)} 个无原始信息的节点生成基础信息...")
        for func_node in nodes_without_raw_info:
            func_node.summary = f"函数 {func_node.name},位于 {func_node.file_path}"
            func_node.logic_flow = "无法生成逻辑流程(缺少原始代码信息)"
            func_node.embedding = [0.0] * embedding_dim
            placeholder_nodes.append(func_node)
            if func_node.raw_attributes is None:
                func_node.raw_attributes = {}
            called_functions, caller_functions = call_relations(func_node)
            func_node.raw_attributes.update({
                "calls": called_functions,
                "called_by": caller_functions,
                "includes": FILE_DEPENDENCIES.get(func_node.file_path, []),
                "qualified_name": _qualified_name(func_node),
                "embedding_available": False,
                "embedding_error": "Missing source information.",
            })

    # 4. Make placeholder vectors as long as the real embeddings so the
    #    vectors can later be stacked into one fixed-width index.
    if observed_dim is not None and observed_dim != embedding_dim:
        for func_node in placeholder_nodes:
            func_node.embedding = [0.0] * observed_dim
def build_call_edges(all_nodes, all_edges, symbol_resolver):
    """Phase 2: resolve call targets and add CALLS edges.

    For every caller recorded in ``FUNCTION_CALL_GRAPH`` whose function node
    exists, each callee name is resolved through ``symbol_resolver``.
    Resolved calls become CALLS edges and are mirrored into
    ``CALLED_BY_GRAPH`` (callee name -> caller names), which is rebuilt from
    scratch on every call. Unresolved callees are ignored.

    :param all_nodes: all graph nodes created so far
    :param all_edges: edge list that receives the new CALLS edges
    :param symbol_resolver: SymbolResolver holding the registered entities
    """
    logger.info("阶段二:解析符号,构建函数调用图 (CALLS 边)...")

    # Rebuild the reverse call graph.
    CALLED_BY_GRAPH.clear()

    # Index node id -> file path once instead of scanning all_nodes for each
    # caller and callee. The first node wins on duplicate ids, matching a
    # front-to-back scan.
    file_path_by_id = {}
    for graph_node in all_nodes:
        file_path_by_id.setdefault(graph_node.id, graph_node.file_path)

    for caller_name, callee_names in FUNCTION_CALL_GRAPH.items():
        caller_node_id = f"Function:{caller_name}"
        # Skip callers that never became a node.
        if caller_node_id not in file_path_by_id:
            continue
        caller_file = file_path_by_id[caller_node_id]

        for callee_name in callee_names:
            callee_node_id = symbol_resolver.resolve_name(
                callee_name,
                {'current_file': caller_file}
            )
            if not callee_node_id:
                continue
            all_edges.append(GraphEdge(
                source_id=caller_node_id,
                target_id=callee_node_id,
                type=EdgeType.CALLS
            ))
            # CALLED_BY_GRAPH is a defaultdict(list), so the entry is created
            # on first use.
            CALLED_BY_GRAPH[callee_name].append(caller_name)

    logger.info(f"阶段二完成。已构建边数: {len(all_edges)}")
def save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=None):
    """Phase 4: persist the graph, the vector index and its metadata.

    Writes ``<base_name>_code_knowledge_graph.json`` unconditionally. When at
    least one function node carries an embedding, also writes
    ``<base_name>_rag.faiss`` (a FAISS ``IndexFlatL2`` when faiss is
    importable, otherwise a JSON vector dump under the same file name) and
    ``<base_name>_rag_metadata.json``, whose entries are in the same order
    as the stored vectors.

    All stored vectors must share one length. The length is taken from the
    first node with a real (non-zero) embedding, falling back to the first
    node. Zero placeholder vectors of another length are resized; real
    vectors of another length are skipped with a warning.

    :param all_nodes: graph nodes to serialise
    :param all_edges: graph edges to serialise
    :param base_name: file name prefix for the three output files
    :param target_path: project directory or source file that was analysed
    :param mode: ``"full_project"`` or ``"single_file"``
    :param output_dir: target directory; None means the current directory
    :return: ``(graph_data_path, vector_db_path, metadata_path)`` as strings;
        the last two files exist only if an index was written
    :raises RuntimeError: if faiss is available but numpy is not
    """
    logger.info("阶段四:持久化图谱与向量索引...")
    output_path = _normalize_output_dir(output_dir)
    graph_data_path = str(output_path / f"{base_name}_code_knowledge_graph.json")
    vector_db_path = str(output_path / f"{base_name}_rag.faiss")
    metadata_path = str(output_path / f"{base_name}_rag_metadata.json")

    # 1. Graph file.
    graph_data = {
        "metadata": {
            "project_root": target_path if mode == "full_project" else "single_file",
            "file_path": target_path if mode == "single_file" else None,
            "mode": mode,
            "generated_at": datetime.now().isoformat(),
            "total_nodes": len(all_nodes),
            "total_edges": len(all_edges)
        },
        "nodes": [node.dict() for node in all_nodes],
        "edges": [edge.dict() for edge in all_edges]
    }
    with open(graph_data_path, 'w', encoding='utf-8') as f:
        json.dump(graph_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f" 图谱结构已保存至: {graph_data_path}")
    print(f"✓ 图谱结构已保存: {graph_data_path}")

    # 2. Vector index.
    function_nodes = [
        node for node in all_nodes
        if node.type == NodeType.FUNCTION and node.embedding
    ]
    if not function_nodes:
        logger.warning(" 没有可生成向量的函数节点跳过FAISS索引构建。")
        print("⚠ 无向量数据跳过FAISS索引构建")
        return graph_data_path, vector_db_path, metadata_path

    reference_node = next(
        (node for node in function_nodes if _has_real_embedding(node.embedding)),
        function_nodes[0],
    )
    dimension = len(reference_node.embedding)

    vectors = []
    metadata_for_faiss = []
    for node in function_nodes:
        embedding = node.embedding
        if len(embedding) != dimension:
            if _has_real_embedding(embedding):
                logger.warning(
                    "Skipping %s: embedding length %d does not match index dimension %d",
                    node.id, len(embedding), dimension,
                )
                continue
            # Placeholder of the wrong length: keep the node in the index
            # with a zero vector of the right size.
            embedding = [0.0] * dimension
        vectors.append(embedding)
        metadata_for_faiss.append(_function_metadata(node, dimension))

    if faiss is not None:
        if np is None:
            raise RuntimeError("numpy is required when writing a FAISS index.")
        vectors_np = np.array(vectors, dtype='float32')
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors_np)
        faiss.write_index(index, vector_db_path)
    else:
        # Fallback without faiss: a plain JSON dump under the same file name.
        with open(vector_db_path, 'w', encoding='utf-8') as f:
            json.dump(
                {
                    "format": "simple_l2_vector_index",
                    "dimension": dimension,
                    "vectors": [[float(value) for value in vector] for vector in vectors],
                },
                f,
                ensure_ascii=False,
            )

    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata_for_faiss, f, indent=2, ensure_ascii=False)

    logger.info(f" 向量索引已保存。函数节点数: {len(function_nodes)}")
    logger.info(f" 向量数据库: {vector_db_path}")
    logger.info(f" 元数据: {metadata_path}")
    print(f"✓ 向量索引已保存: {vector_db_path}")
    return graph_data_path, vector_db_path, metadata_path
def build_code_knowledge_base(target_path, output_dir=None, semantic=True, base_name=None, embedding_dim=1024):
    """Build a code knowledge base without interactive prompts.

    Runs static analysis, call-edge construction, optional semantic
    enhancement and persistence for a directory (full-project mode) or a
    single file.

    :param target_path: project directory or source file to analyse
    :param output_dir: directory for the output files; None means the
        current directory
    :param semantic: when True, call the LLM helpers for summaries and
        embeddings; when False, fill function nodes with default text and
        zero vectors
    :param base_name: output file prefix; defaults to ``"project"`` for a
        directory and ``"file_<stem>"`` for a single file
    :param embedding_dim: length of the zero vectors used when ``semantic``
        is False
    :return: ``(graph_data_path, vector_db_path, metadata_path)`` as returned
        by ``save_results``
    :raises FileNotFoundError: if ``target_path`` does not exist
    """
    target_path = str(Path(target_path).resolve())
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path does not exist: {target_path}")

    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}
    # The graph dictionaries are module-level state; start from a clean slate.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    if os.path.isdir(target_path):
        mode = "full_project"
        process_project_files(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or "project"
    else:
        mode = "single_file"
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or f"file_{Path(target_path).stem}"

    build_call_edges(all_nodes, all_edges, symbol_resolver)

    if semantic:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)
    else:
        for node in all_nodes:
            if node.type != NodeType.FUNCTION:
                continue
            raw_info = function_raw_info_map.get(node.id, {})
            node.summary = node.summary or f"Function {node.name}, located at {node.file_path}"
            node.logic_flow = node.logic_flow or "Semantic enhancement was skipped."
            node.embedding = node.embedding or [0.0] * embedding_dim

            # The call graph is keyed by qualified name ("Class::method" for
            # members), so look that up first and fall back to the plain name.
            qualified = _qualified_name(node)
            if qualified in FUNCTION_CALL_GRAPH:
                calls = list(FUNCTION_CALL_GRAPH[qualified])
            else:
                calls = list(FUNCTION_CALL_GRAPH.get(node.name, []))
            called_by = list(CALLED_BY_GRAPH.get(node.name, []))
            if qualified != node.name:
                called_by.extend(
                    caller for caller in CALLED_BY_GRAPH.get(qualified, [])
                    if caller not in called_by
                )

            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                "calls": calls,
                "called_by": called_by,
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })

    return save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=output_dir)
def _confirm_path_retry():
    """Ask whether to re-enter a path; announce the cancellation when declined."""
    retry = input("是否重新输入? (y/n): ").strip().lower()
    if retry == 'y':
        return True
    print("知识库构建取消。")
    return False


def _prompt_project_root():
    """Prompt for a project directory; return it, or None when the user cancels."""
    print("\n" + "=" * 60)
    print("全项目知识库构建模式")
    print("=" * 60)

    # Show the default path configured in config.py.
    current_default_path = str(PROJECT_ROOT) if PROJECT_ROOT else "未配置"
    if not current_default_path or current_default_path == ".":
        current_default_path = "当前工作目录"
    project_prompt = f"请直接输入要构建知识库的项目根目录完整路径\n(当前config.py中的默认路径: {current_default_path}): "

    while True:
        user_input = input(project_prompt).strip()
        if user_input:
            target_path = Path(user_input)
        elif PROJECT_ROOT and os.path.exists(PROJECT_ROOT) and os.path.isdir(PROJECT_ROOT):
            # Empty input: fall back to the configured default.
            target_path = PROJECT_ROOT
            print(f"使用config.py中的默认路径: {target_path}")
        else:
            print("⚠ 您没有输入路径且config.py中的默认路径无效。")
            if not _confirm_path_retry():
                return None
            continue

        if not os.path.exists(target_path):
            print(f"⚠ 路径不存在: {target_path}")
            if not _confirm_path_retry():
                return None
            continue
        if not os.path.isdir(target_path):
            print(f"⚠ 路径不是目录: {target_path}")
            if not _confirm_path_retry():
                return None
            continue
        return target_path


def _prompt_single_file():
    """Prompt for a source file path; return it, or None on empty input."""
    while True:
        file_path = input("\n请输入要处理的文件完整路径: ").strip()
        if not file_path:
            print("输入为空,退出单文件模式")
            return None
        if not os.path.exists(file_path):
            print(f"错误: 文件不存在: {file_path}")
            continue
        if not os.path.isfile(file_path):
            # A directory would only fail later, when the file is opened.
            print(f"错误: 路径不是文件: {file_path}")
            continue
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in CPP_EXTENSIONS:
            print(f"警告: 文件类型 {file_ext} 可能不是C/C++文件,继续吗? (y/n): ", end="")
            if input().strip().lower() != 'y':
                continue
        return file_path


def _fill_skipped_semantic_defaults(all_nodes, function_raw_info_map):
    """Give function nodes without a summary default text and a zero vector."""
    for node in all_nodes:
        if node.type != NodeType.FUNCTION or node.summary:
            continue
        node.summary = f"函数 {node.name},位于 {node.file_path}"
        node.logic_flow = "跳过语义增强阶段"
        node.embedding = [0.0] * 1024
        raw_info = function_raw_info_map.get(node.id, {})

        # The call graph is keyed by qualified name ("Class::method" for
        # members), so look that up first and fall back to the plain name.
        qualified = _qualified_name(node)
        if qualified in FUNCTION_CALL_GRAPH:
            calls = list(FUNCTION_CALL_GRAPH[qualified])
        else:
            calls = list(FUNCTION_CALL_GRAPH.get(node.name, []))
        called_by = list(CALLED_BY_GRAPH.get(node.name, []))
        if qualified != node.name:
            called_by.extend(
                caller for caller in CALLED_BY_GRAPH.get(qualified, [])
                if caller not in called_by
            )

        node.raw_attributes.update({
            "comment": raw_info.get("comment", ""),
            "code_snippet": raw_info.get("code_snippet", ""),
            "calls": calls,
            "called_by": called_by,
            "includes": FILE_DEPENDENCIES.get(node.file_path, []),
            "qualified_name": qualified,
            "embedding_available": False,
            "embedding_error": "Semantic enhancement was skipped.",
        })


def build_rag_database_interactive():
    """Interactive main flow for building the knowledge base.

    Asks for the build mode (whole project or single file) and the target
    path, then runs the four phases: static analysis, call-edge
    construction, optional LLM-based semantic enhancement, and optional
    persistence. The user confirms before the build starts, before phase
    three and before saving.

    :return: ``(all_nodes, all_edges)`` once the analysis has run (whether
        or not the results were saved), or ``(None, None)`` when the user
        exits or cancels before the analysis starts
    """
    logger.info("=== 交互式知识图谱构建流程 ===")

    # Mode selection menu.
    print("\n" + "=" * 50)
    print("知识库构建模式选择")
    print("=" * 50)
    print("1. 全项目模式 - 处理整个项目目录")
    print("2. 单文件模式 - 处理单个文件")
    print("3. 退出")
    print("=" * 50)
    choice = input("请选择模式 (1-3): ").strip()

    if choice == "3":
        logger.info("用户选择退出")
        return None, None

    if choice == "1":
        target_path = _prompt_project_root()
        if target_path is None:
            return None, None
        mode = "full_project"
        print(f"\n✓ 将处理整个项目: {target_path}")
        print("开始处理...")
    elif choice == "2":
        target_path = _prompt_single_file()
        if target_path is None:
            return None, None
        mode = "single_file"
        print(f"\n将处理单个文件: {target_path}")
        print("开始处理...")
    else:
        print("无效的选择,退出")
        return None, None

    # Final confirmation before any work is done.
    print("\n即将开始构建知识库,确认继续? (y/n): ", end="")
    if input().strip().lower() != 'y':
        print("操作已取消")
        return None, None

    # --- Initialise the data structures ---
    all_nodes = []                      # GraphNode objects
    all_edges = []                      # GraphEdge objects
    symbol_resolver = SymbolResolver()  # name -> node id resolver
    function_raw_info_map = {}          # temporary per-function raw info
    # Reset the module-level graphs.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    # ========================
    # Phase 1: static analysis - nodes and structural edges
    # ========================
    logger.info("阶段一:静态分析,提取代码实体与结构关系...")
    if choice == "1":
        processed_files_count = process_project_files(target_path, all_nodes, all_edges,
                                                      symbol_resolver, function_raw_info_map)
        logger.info(f"已处理 {processed_files_count} 个文件")
    else:
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        logger.info(f"已处理单个文件: {target_path}")
    logger.info(f"阶段一完成。初步识别节点数: {len(all_nodes)}")

    # ========================
    # Phase 2: symbol resolution - call edges
    # ========================
    build_call_edges(all_nodes, all_edges, symbol_resolver)

    # ========================
    # Phase 3: semantic enhancement - summaries and vectors
    # ========================
    print("\n" + "=" * 50)
    print("阶段三:语义增强")
    print("=" * 50)
    print("此阶段将调用LLM API为函数生成摘要和向量嵌入")
    print(f"需要处理的函数节点数: {len([n for n in all_nodes if n.type == NodeType.FUNCTION])}")
    print("这可能会消耗API调用额度并需要一些时间")
    print("-" * 50)
    proceed = input("是否继续阶段三? (y/n): ").strip().lower()
    if proceed != 'y':
        logger.info("用户跳过阶段三(语义增强)")
        print("跳过阶段三,直接进入阶段四")
        _fill_skipped_semantic_defaults(all_nodes, function_raw_info_map)
    else:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)

    # ========================
    # Phase 4: persistence
    # ========================
    print("\n" + "=" * 50)
    print("阶段四:持久化存储")
    print("=" * 50)
    save_choice = input("是否保存结果? (y/n): ").strip().lower()
    if save_choice != 'y':
        logger.info("用户取消保存,退出构建流程")
        print("构建流程已取消")
        return all_nodes, all_edges

    # Output file prefix.
    if choice == "1":
        base_name = "project"
    else:
        base_name = f"file_{Path(target_path).stem}"

    graph_data_path, vector_db_path, metadata_path = save_results(
        all_nodes, all_edges, base_name, target_path, mode
    )

    print("\n" + "=" * 50)
    print("构建完成!")
    print("=" * 50)
    print(f"总节点数: {len(all_nodes)}")
    print(f"总边数: {len(all_edges)}")
    print(f"图谱文件: {graph_data_path}")
    # The vector file exists only when an index was actually written.
    if os.path.exists(vector_db_path):
        print(f"向量文件: {vector_db_path}")
    logger.info("=== 知识图谱构建流程全部完成 ===")
    return all_nodes, all_edges