826 lines
33 KiB
Python
826 lines
33 KiB
Python
"""
graph_builder.py - Core pipeline controller for building the code knowledge graph.

Contains the complete build pipeline and the interactive workflow.
"""
|
||
import os
|
||
import json
|
||
try:
|
||
import faiss
|
||
except ImportError: # pragma: no cover - depends on deployment environment
|
||
faiss = None
|
||
try:
|
||
import numpy as np
|
||
except ImportError: # pragma: no cover - depends on deployment environment
|
||
np = None
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
import logging
|
||
from tqdm import tqdm
|
||
from collections import defaultdict
|
||
|
||
from models import GraphNode, GraphEdge, NodeType, EdgeType
|
||
from code_parser import (
|
||
parser, CPP_EXTENSIONS, IGNORE_DIRS,
|
||
safe_decode, extract_comment_before, extract_code_snippet,
|
||
extract_full_function_name, collect_call_expressions,
|
||
extract_class_name, extract_member_function_name, extract_base_classes
|
||
)
|
||
from llm_processor import generate_code_logic, generate_code_summary, get_qwen_embedding
|
||
from config import PROJECT_ROOT, MAX_CODE_LENGTH, QWEN_EMBEDDING_MODEL
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Global graph data structures.
# Function name (qualified as "Class::method" for member functions) -> the call
# expressions collected from that function's body by process_single_file.
FUNCTION_CALL_GRAPH = defaultdict(list)
# File path -> include paths found in that file by process_single_file.
FILE_DEPENDENCIES = defaultdict(list)
# Callee name -> names of the functions calling it; rebuilt by build_call_edges.
CALLED_BY_GRAPH = defaultdict(list)
|
||
|
||
|
||
def _as_node_type(value):
|
||
return value.value if hasattr(value, "value") else value
|
||
|
||
|
||
def _normalize_output_dir(output_dir):
|
||
if output_dir is None:
|
||
return Path(".")
|
||
path = Path(output_dir)
|
||
path.mkdir(parents=True, exist_ok=True)
|
||
return path
|
||
|
||
|
||
def _qualified_name(node):
|
||
if node.id.startswith("Function:"):
|
||
return node.id[len("Function:"):]
|
||
return node.name
|
||
|
||
|
||
def _has_real_embedding(embedding):
|
||
if not embedding:
|
||
return False
|
||
try:
|
||
return any(abs(float(value)) > 1e-12 for value in embedding)
|
||
except (TypeError, ValueError):
|
||
return False
|
||
|
||
|
||
def _function_metadata(node, embedding_dim=None):
    """Build the metadata record stored next to a function's vector.

    :param node: function node providing id, name, location, signature,
        summary, logic flow, embedding and ``raw_attributes``
    :param embedding_dim: dimension to report; when falsy, the length of
        ``node.embedding`` is used (0 when the node has no embedding)
    :return: dict describing the function
    """
    attrs = node.raw_attributes if node.raw_attributes else {}

    # Prefer the flag recorded on the node; otherwise inspect the vector itself.
    available = attrs.get("embedding_available")
    if available is None:
        available = _has_real_embedding(node.embedding)

    if embedding_dim:
        dim = embedding_dim
    elif node.embedding:
        dim = len(node.embedding)
    else:
        dim = 0

    qualified = attrs.get("qualified_name") or _qualified_name(node)

    metadata = {}
    metadata["node_id"] = node.id
    metadata["name"] = node.name
    metadata["function_name"] = node.name
    metadata["qualified_name"] = qualified
    metadata["file"] = node.file_path
    metadata["file_path"] = node.file_path
    metadata["start_line"] = node.start_line
    metadata["end_line"] = node.end_line
    metadata["signature"] = node.signature
    metadata["summary"] = node.summary or ""
    metadata["logic_flow"] = node.logic_flow or ""
    metadata["code_snippet"] = attrs.get("code_snippet", "")
    metadata["calls"] = attrs.get("calls", [])
    metadata["called_by"] = attrs.get("called_by", [])
    metadata["includes"] = attrs.get("includes", [])
    metadata["embedding_model"] = QWEN_EMBEDDING_MODEL
    metadata["embedding_dim"] = dim
    metadata["embedding_available"] = bool(available)
    metadata["embedding_error"] = attrs.get("embedding_error", "")
    return metadata
|
||
|
||
|
||
class SymbolResolver:
    """Symbol resolver: maps names found in code to graph node IDs."""

    def __init__(self):
        # Entity name -> node ID. Registering a second entity under the same
        # name overwrites the earlier entry.
        self._symbol_table = {}
        # File path -> IDs of the entities registered for that file.
        self._file_to_entities = defaultdict(list)

    def register_entity(self, node: "GraphNode"):
        """Register a newly discovered entity in the symbol table.

        :param node: graph node providing ``name``, ``id`` and ``file_path``
        """
        self._symbol_table[node.name] = node.id
        if node.file_path:
            self._file_to_entities[node.file_path].append(node.id)

    def resolve_name(self, name: str, current_context: dict) -> "str | None":
        """Resolve a name used in code to a node ID.

        An exact match on the registered name wins. Otherwise the first
        registered name (in registration order) that ends with ``"::" + name``
        is used.

        :param name: the name as written in the code
        :param current_context: current context (file, class, ...); accepted for
            interface compatibility and currently not consulted
        :return: the node ID, or None when the name is empty, is not a string,
            or cannot be resolved
        """
        # An empty or non-string name can never match; without this guard a None
        # name would raise TypeError when the suffix is built below.
        if not isinstance(name, str) or not name:
            return None

        if name in self._symbol_table:
            return self._symbol_table[name]

        suffix = "::" + name
        for registered_name, node_id in self._symbol_table.items():
            if registered_name.endswith(suffix):
                return node_id

        return None
|
||
|
||
|
||
def _declaration_signature(ast_node):
    """Return the text of *ast_node* before its first '{', followed by ' {...}'."""
    header = ast_node.text.decode('utf-8', errors='ignore').split('{')[0].strip()
    return header + ' {...}'


def _collect_raw_info(tree, ast_node, source_lines):
    """Bundle the AST node, its leading comment and code snippet for later phases."""
    return {
        'ast_node': ast_node,
        'comment': extract_comment_before(ast_node, source_lines),
        'code_snippet': extract_code_snippet(tree, ast_node, source_lines),
        'source_lines': source_lines,
        'tree': tree
    }


def process_single_file(file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Parse one C/C++ file and add its code entities and structural edges.

    Creates a FILE node, a FUNCTION node for every top-level function
    definition, a CLASS node for every top-level class/struct and a FUNCTION
    node for every method defined inside such a class, together with CONTAINS
    and EXTENDS/IMPLEMENTS edges. Include paths are stored in
    ``FILE_DEPENDENCIES`` and the call expressions of each function in
    ``FUNCTION_CALL_GRAPH``.

    :param file_path: path of the file to parse
    :param all_nodes: list receiving the created GraphNode objects
    :param all_edges: list receiving the created GraphEdge objects
    :param symbol_resolver: resolver every new node is registered with; it is
        also used to resolve base-class names, so only bases registered before
        the derived class produce an inheritance edge
    :param function_raw_info_map: dict receiving, per function node ID, the AST
        node, leading comment, code snippet, source lines and syntax tree
    :raises Exception: any failure is logged with its traceback and re-raised
    """
    logger.info(f"处理单个文件: {file_path}")

    try:
        with open(file_path, 'rb') as f:
            source_bytes = f.read()
        source_lines = source_bytes.split(b'\n')
        tree = parser.parse(source_bytes)
        top_level_nodes = tree.root_node.named_children

        # 1. Create the file node.
        file_node_id = f"File:{file_path}"
        file_node = GraphNode(
            id=file_node_id,
            type=NodeType.FILE,
            name=Path(file_path).name,
            file_path=file_path
        )
        all_nodes.append(file_node)
        symbol_resolver.register_entity(file_node)

        # IDs already present in the graph, so the duplicate check for member
        # functions is a set lookup instead of a scan over all nodes.
        known_node_ids = {existing.id for existing in all_nodes}

        # 2. Collect header dependencies.
        includes = []
        for node in top_level_nodes:
            if node.type != "preproc_include":
                continue
            path_node = node.child_by_field_name("path")
            if path_node:
                # errors='ignore' matches the other decodes in this function, so
                # a non-UTF-8 include path cannot abort the whole file.
                include_text = path_node.text.decode('utf-8', errors='ignore').strip('"<>')
                includes.append(include_text)
        FILE_DEPENDENCIES[file_path] = includes

        # 3. Walk the top-level declarations of the file.
        for node in top_level_nodes:
            # 3.1 Function definitions.
            if node.type == "function_definition":
                func_name = extract_full_function_name(node)
                if func_name == "<unknown>":
                    logger.warning(f" 跳过无法解析函数名的节点: {node.text[:50]}...")
                    continue

                # Create the function node.
                func_node_id = f"Function:{func_name}"
                func_node = GraphNode(
                    id=func_node_id,
                    type=NodeType.FUNCTION,
                    name=func_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={'ast_node_type': 'function_definition'}
                )
                all_nodes.append(func_node)
                known_node_ids.add(func_node_id)
                symbol_resolver.register_entity(func_node)

                # Keep the raw information for the semantic phase.
                function_raw_info_map[func_node_id] = _collect_raw_info(tree, node, source_lines)

                # FILE -> CONTAINS -> FUNCTION edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=func_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Record the calls made by this function.
                FUNCTION_CALL_GRAPH[func_name] = collect_call_expressions(node)

            # 3.2 Class / struct definitions.
            elif node.type in ["class_specifier", "struct_specifier"]:
                class_name = extract_class_name(node)
                if not class_name or class_name == "<unknown>":
                    logger.debug(" 跳过无法解析类名的节点")
                    continue

                # Create the class node.
                class_node_id = f"Class:{class_name}"
                class_node = GraphNode(
                    id=class_node_id,
                    type=NodeType.CLASS,
                    name=class_name,
                    signature=_declaration_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={
                        'ast_type': node.type,
                        'is_struct': node.type == "struct_specifier"
                    }
                )
                all_nodes.append(class_node)
                known_node_ids.add(class_node_id)
                symbol_resolver.register_entity(class_node)

                # FILE -> CONTAINS -> CLASS edge.
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=class_node_id,
                    type=EdgeType.CONTAINS
                ))

                # Base classes / inheritance.
                # NOTE(review): the tree-sitter C++ grammar appears to expose the
                # base clause as a plain "base_class_clause" child rather than a
                # "base_clause" field, so fall back to a lookup by node type.
                # Confirm against the grammar version used by code_parser.
                base_clause = node.child_by_field_name("base_clause")
                if base_clause is None:
                    base_clause = next(
                        (child for child in node.named_children
                         if child.type == "base_class_clause"),
                        None
                    )
                if base_clause is not None:
                    for base_class_info in extract_base_classes(base_clause):
                        base_name = base_class_info.get('name')
                        access_type = base_class_info.get('access', 'public')
                        if access_type in ['public', 'protected', 'private']:
                            edge_type = EdgeType.EXTENDS
                        else:
                            edge_type = EdgeType.IMPLEMENTS

                        if not base_name:
                            continue
                        base_id = symbol_resolver.resolve_name(
                            base_name,
                            {'current_file': file_path, 'current_class': class_name}
                        )
                        if base_id:
                            all_edges.append(GraphEdge(
                                source_id=class_node_id,
                                target_id=base_id,
                                type=edge_type,
                                properties={'access': access_type}
                            ))

                # Methods defined inside the class.
                # NOTE(review): in the tree-sitter C++ grammar the members sit
                # under the "body" field, not directly under the class node, so
                # both levels are scanned. Confirm against the grammar in use.
                member_candidates = list(node.named_children)
                body_node = node.child_by_field_name("body")
                if body_node is not None:
                    member_candidates.extend(body_node.named_children)

                for member in member_candidates:
                    if member.type != "function_definition":
                        continue
                    member_func_name = extract_member_function_name(member, class_name)
                    if not member_func_name or member_func_name == "<unknown>":
                        continue

                    qualified_member_name = f"{class_name}::{member_func_name}"
                    member_node_id = f"Function:{qualified_member_name}"

                    if member_node_id not in known_node_ids:
                        member_node = GraphNode(
                            id=member_node_id,
                            type=NodeType.FUNCTION,
                            name=member_func_name,
                            signature=_declaration_signature(member),
                            file_path=file_path,
                            start_line=member.start_point[0] + 1,
                            end_line=member.end_point[0] + 1,
                            raw_attributes={
                                'is_member': True,
                                'parent_class': class_name,
                                'ast_node_type': 'member_function_definition'
                            }
                        )
                        all_nodes.append(member_node)
                        known_node_ids.add(member_node_id)
                        symbol_resolver.register_entity(member_node)

                        # Keep the raw information for the semantic phase.
                        function_raw_info_map[member_node_id] = _collect_raw_info(
                            tree, member, source_lines
                        )

                        # Record the calls made by this member function.
                        FUNCTION_CALL_GRAPH[qualified_member_name] = collect_call_expressions(member)

                    # CLASS -> CONTAINS -> method edge, also added when the
                    # method node already existed.
                    all_edges.append(GraphEdge(
                        source_id=class_node_id,
                        target_id=member_node_id,
                        type=EdgeType.CONTAINS,
                        properties={'member_type': 'method'}
                    ))

    except Exception as e:
        logger.error(f"解析文件 {file_path} 失败: {e}", exc_info=True)
        raise
|
||
|
||
|
||
def process_project_files(project_root, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Process every C/C++ file below *project_root*.

    Directories listed in ``IGNORE_DIRS`` are not descended into, and only
    files whose lower-cased suffix is in ``CPP_EXTENSIONS`` are parsed.

    :param project_root: root directory of the project
    :param all_nodes: list receiving the created graph nodes
    :param all_edges: list receiving the created graph edges
    :param symbol_resolver: resolver passed on to ``process_single_file``
    :param function_raw_info_map: dict passed on to ``process_single_file``
    :return: number of files handed to ``process_single_file``
    """
    logger.info(f"处理项目目录: {project_root}")

    # Count the files actually processed. The previous ``len(files)`` only
    # reflected the last directory visited and was unbound for an empty walk.
    processed_count = 0
    for root, dirs, files in os.walk(project_root):
        # Prune ignored directories in place so os.walk does not descend into them.
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]

        for file in files:
            if Path(file).suffix.lower() not in CPP_EXTENSIONS:
                continue

            full_file_path = os.path.join(root, file)
            process_single_file(full_file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
            processed_count += 1

    return processed_count
|
||
|
||
|
||
def semantic_enhancement(all_nodes, all_edges, function_raw_info_map):
    """Phase 3: generate summaries and embeddings for the function nodes.

    For every function node with an entry in *function_raw_info_map*, the logic
    flow, the summary and the embedding are generated through the LLM helpers.
    A failing step is logged and replaced by a placeholder text or zero vector.
    Function nodes without raw information only receive basic placeholder
    information. Call relations and includes are copied into each node's
    ``raw_attributes``.

    :param all_nodes: all graph nodes; function nodes are updated in place
    :param all_edges: all graph edges (not used by this phase)
    :param function_raw_info_map: function node ID -> raw info with at least
        the keys ``'comment'`` and ``'code_snippet'``
    """
    logger.info("阶段三:语义增强,为函数节点生成摘要与向量嵌入...")

    # Dimension used for placeholder vectors when no real embedding was obtained.
    default_embedding_dim = 1024

    def call_relations(node):
        """Return ``(calls, called_by)`` for *node*.

        FUNCTION_CALL_GRAPH is keyed by the qualified name ("Class::method"
        for members), so the lookup must not use the bare ``node.name``.
        Callers are collected under both the qualified and the bare name,
        because CALLED_BY_GRAPH is keyed by the name as written at the call site.
        """
        qualified = _qualified_name(node)
        calls = FUNCTION_CALL_GRAPH.get(qualified, [])
        called_by = list(CALLED_BY_GRAPH.get(qualified, []))
        if node.name != qualified:
            for caller in CALLED_BY_GRAPH.get(node.name, []):
                if caller not in called_by:
                    called_by.append(caller)
        return calls, called_by

    # 1. Split the function nodes by availability of raw information.
    nodes_with_raw_info = []
    nodes_without_raw_info = []
    for func_node in all_nodes:
        if func_node.type != NodeType.FUNCTION:
            continue
        if func_node.id in function_raw_info_map:
            nodes_with_raw_info.append(func_node)
        else:
            nodes_without_raw_info.append(func_node)

    logger.info(f"有原始信息的节点: {len(nodes_with_raw_info)} 个")
    if nodes_without_raw_info:
        logger.warning(f"无原始信息的节点: {len(nodes_without_raw_info)} 个")

    # Dimension of the first real embedding, and the nodes needing a placeholder.
    embedding_dim = None
    nodes_needing_placeholder = []

    # 2. Nodes with raw information.
    if nodes_with_raw_info:
        total = len(nodes_with_raw_info)
        print(f"\n开始为 {total} 个函数生成摘要和向量...")
        print("这可能需要一些时间,请耐心等待...")

        processed_count = 0
        for func_node in tqdm(nodes_with_raw_info, desc="生成函数摘要", unit="个"):
            func_name = func_node.name
            file_path = func_node.file_path

            raw_info = function_raw_info_map[func_node.id]
            comment = raw_info['comment']
            code_snippet = raw_info['code_snippet']
            func_node.raw_attributes.update({
                "comment": comment,
                "code_snippet": code_snippet,
                "includes": FILE_DEPENDENCIES.get(file_path, []),
                "qualified_name": _qualified_name(func_node),
            })

            # Step 1: logic flow.
            try:
                logic = generate_code_logic(func_name, comment, code_snippet, file_path)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成逻辑失败: {e}")
                logic = "逻辑生成失败。"

            # Step 2: functional summary.
            try:
                summary = generate_code_summary(
                    func_name, comment, logic, code_snippet, file_path
                )
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成摘要失败: {e}")
                summary = "摘要生成失败。"

            # Step 3: update the node.
            func_node.summary = summary
            func_node.logic_flow = logic

            called_functions, caller_functions = call_relations(func_node)
            func_node.raw_attributes.update({
                'calls': called_functions,
                'called_by': caller_functions,
            })

            # Step 4: embedding.
            context_text_for_embedding = (
                f"函数名: {func_name}\n"
                f"功能摘要: {summary}\n"
                f"流程逻辑: {logic}\n"
                f"所在文件: {os.path.basename(file_path)}"
            )

            try:
                embedding_vector = get_qwen_embedding(context_text_for_embedding)
                if hasattr(embedding_vector, "tolist"):
                    embedding_values = embedding_vector.tolist()
                else:
                    embedding_values = list(embedding_vector)
                if not _has_real_embedding(embedding_values):
                    raise RuntimeError("Embedding API returned an empty or zero vector.")
                func_node.embedding = embedding_values
                func_node.raw_attributes["embedding_available"] = True
                func_node.raw_attributes["embedding_error"] = ""
                if embedding_dim is None:
                    embedding_dim = len(embedding_values)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成嵌入失败: {e}")
                # The placeholder is assigned after the loop, once the real
                # embedding dimension is known.
                nodes_needing_placeholder.append(func_node)
                func_node.raw_attributes["embedding_available"] = False
                func_node.raw_attributes["embedding_error"] = str(e)[:500]

            processed_count += 1
            # Report progress every 5 functions.
            if processed_count % 5 == 0:
                print(f" 已处理 {processed_count}/{total} 个函数")

            # Small delay to avoid API rate limiting.
            time.sleep(0.1)

        print(f"✓ 完成 {total} 个函数的语义增强")
    else:
        logger.warning("没有找到任何有原始信息的函数节点,跳过摘要生成。")

    # Placeholders must have the same length as the real embeddings; otherwise
    # the vectors cannot be stacked into a single index later.
    placeholder_dim = embedding_dim or default_embedding_dim
    for func_node in nodes_needing_placeholder:
        func_node.embedding = [0.0] * placeholder_dim

    # 3. Nodes without raw information.
    if nodes_without_raw_info:
        print(f"\n为 {len(nodes_without_raw_info)} 个无原始信息的节点生成基础信息...")
        for func_node in nodes_without_raw_info:
            func_node.summary = f"函数 {func_node.name},位于 {func_node.file_path}"
            func_node.logic_flow = "无法生成逻辑流程(缺少原始代码信息)"
            func_node.embedding = [0.0] * placeholder_dim
            called_functions, caller_functions = call_relations(func_node)
            func_node.raw_attributes.update({
                "calls": called_functions,
                "called_by": caller_functions,
                "includes": FILE_DEPENDENCIES.get(func_node.file_path, []),
                "qualified_name": _qualified_name(func_node),
                "embedding_available": False,
                "embedding_error": "Missing source information.",
            })
|
||
|
||
|
||
def build_call_edges(all_nodes, all_edges, symbol_resolver):
    """Phase 2: build the function call edges (CALLS).

    For every caller recorded in ``FUNCTION_CALL_GRAPH`` whose node exists,
    each callee name is resolved through *symbol_resolver*. Resolved callees
    produce a CALLS edge and an entry in ``CALLED_BY_GRAPH``, which is rebuilt
    from scratch.

    :param all_nodes: all graph nodes
    :param all_edges: list receiving the CALLS edges
    :param symbol_resolver: resolver exposing ``resolve_name(name, context)``
    """
    logger.info("阶段二:解析符号,构建函数调用图 (CALLS 边)...")

    # Rebuild the reverse call graph.
    CALLED_BY_GRAPH.clear()

    # Node ID -> file path of the first node with that ID, built once so that
    # existence checks and context lookups are not scans over all_nodes.
    file_path_by_id = {}
    for graph_node in all_nodes:
        file_path_by_id.setdefault(graph_node.id, graph_node.file_path)

    for caller_name, callee_names in FUNCTION_CALL_GRAPH.items():
        caller_node_id = f"Function:{caller_name}"
        # Skip callers that have no node in the graph.
        if caller_node_id not in file_path_by_id:
            continue
        caller_file = file_path_by_id[caller_node_id]

        for callee_name in callee_names:
            # Resolve the name to a node ID through the symbol resolver.
            callee_node_id = symbol_resolver.resolve_name(
                callee_name,
                {'current_file': caller_file}
            )
            if not callee_node_id:
                continue

            # Add the CALLS edge.
            all_edges.append(GraphEdge(
                source_id=caller_node_id,
                target_id=callee_node_id,
                type=EdgeType.CALLS
            ))
            # Update the reverse graph; the defaultdict creates the list.
            CALLED_BY_GRAPH[callee_name].append(caller_name)

    logger.info(f"阶段二完成。已构建边数: {len(all_edges)}")
|
||
|
||
|
||
def save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=None):
    """Phase 4: persist the graph, the vector index and its metadata.

    Writes ``<base_name>_code_knowledge_graph.json`` and, when at least one
    function node has an embedding, ``<base_name>_rag.faiss`` plus
    ``<base_name>_rag_metadata.json``. When the ``faiss`` package is not
    available, the ``.faiss`` file holds a JSON document with the raw vectors.

    :param all_nodes: graph nodes to store
    :param all_edges: graph edges to store
    :param base_name: prefix of the output file names
    :param target_path: analysed project directory or file
    :param mode: ``"full_project"`` or ``"single_file"``
    :param output_dir: output directory; ``None`` means the current directory
    :return: tuple ``(graph_data_path, vector_db_path, metadata_path)``; the
        last two are returned even when no index was written
    :raises RuntimeError: when faiss is available but numpy is not
    """
    logger.info("阶段四:持久化图谱与向量索引...")

    output_path = _normalize_output_dir(output_dir)
    graph_data_path = str(output_path / f"{base_name}_code_knowledge_graph.json")
    vector_db_path = str(output_path / f"{base_name}_rag.faiss")
    metadata_path = str(output_path / f"{base_name}_rag_metadata.json")

    # 1. Save the graph file.
    graph_data = {
        "metadata": {
            "project_root": target_path if mode == "full_project" else "single_file",
            "file_path": target_path if mode == "single_file" else None,
            "mode": mode,
            "generated_at": datetime.now().isoformat(),
            "total_nodes": len(all_nodes),
            "total_edges": len(all_edges)
        },
        "nodes": [node.dict() for node in all_nodes],
        "edges": [edge.dict() for edge in all_edges]
    }

    with open(graph_data_path, 'w', encoding='utf-8') as f:
        # default=str stringifies values json cannot encode (e.g. Path objects).
        json.dump(graph_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f" 图谱结构已保存至: {graph_data_path}")
    print(f"✓ 图谱结构已保存: {graph_data_path}")

    # 2. Build and save the vector index.
    function_nodes = [node for node in all_nodes if node.type == NodeType.FUNCTION and node.embedding]
    if function_nodes:
        # Take the dimension from the first real embedding; fall back to the
        # first node when every vector is a zero placeholder.
        reference_node = next(
            (node for node in function_nodes if _has_real_embedding(node.embedding)),
            function_nodes[0]
        )
        dimension = len(reference_node.embedding)

        vectors = []
        metadata_for_faiss = []
        for node in function_nodes:
            vector = node.embedding
            if len(vector) != dimension:
                if _has_real_embedding(vector):
                    # A real vector of another size cannot share this index.
                    logger.warning(
                        f"Skipping {node.id}: embedding dimension {len(vector)} "
                        f"does not match index dimension {dimension}."
                    )
                    continue
                # A zero placeholder carries no information; resize it so all
                # rows of the index have the same length.
                vector = [0.0] * dimension
            vectors.append(vector)
            metadata_for_faiss.append(_function_metadata(node, dimension))

        if vectors:
            if faiss is not None:
                if np is None:
                    raise RuntimeError("numpy is required when writing a FAISS index.")
                vectors_np = np.array(vectors, dtype='float32')
                index = faiss.IndexFlatL2(dimension)
                index.add(vectors_np)
                faiss.write_index(index, vector_db_path)
            else:
                # Fallback without faiss: store the vectors as plain JSON.
                with open(vector_db_path, 'w', encoding='utf-8') as f:
                    json.dump(
                        {
                            "format": "simple_l2_vector_index",
                            "dimension": dimension,
                            "vectors": [[float(value) for value in vector] for vector in vectors],
                        },
                        f,
                        ensure_ascii=False,
                    )
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata_for_faiss, f, indent=2, ensure_ascii=False)
            logger.info(f" 向量索引已保存。函数节点数: {len(vectors)}")
            logger.info(f" 向量数据库: {vector_db_path}")
            logger.info(f" 元数据: {metadata_path}")
            print(f"✓ 向量索引已保存: {vector_db_path}")
    else:
        logger.warning(" 没有可生成向量的函数节点,跳过FAISS索引构建。")
        print("⚠ 无向量数据,跳过FAISS索引构建")

    return graph_data_path, vector_db_path, metadata_path
|
||
|
||
|
||
def build_code_knowledge_base(target_path, output_dir=None, semantic=True, base_name=None, embedding_dim=1024):
    """Build a code knowledge base without interactive prompts.

    Runs the static analysis, builds the call edges, optionally runs the
    semantic enhancement and finally saves the results.

    :param target_path: project directory or single source file
    :param output_dir: directory for the output files; ``None`` means the
        current directory
    :param semantic: when True, summaries and embeddings are generated through
        ``semantic_enhancement``; when False, placeholders are stored instead
    :param base_name: prefix of the output file names; defaults to
        ``"project"`` for a directory and ``"file_<stem>"`` for a file
    :param embedding_dim: length of the zero placeholder vectors used when
        *semantic* is False
    :return: tuple ``(graph_data_path, vector_db_path, metadata_path)`` as
        returned by ``save_results``
    :raises FileNotFoundError: when *target_path* does not exist
    """
    target_path = str(Path(target_path).resolve())
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path does not exist: {target_path}")

    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}

    # The call/dependency graphs are module-level state; reset them per build.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    if os.path.isdir(target_path):
        mode = "full_project"
        process_project_files(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or "project"
    else:
        mode = "single_file"
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or f"file_{Path(target_path).stem}"

    build_call_edges(all_nodes, all_edges, symbol_resolver)

    if semantic:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)
    else:
        for node in all_nodes:
            if node.type != NodeType.FUNCTION:
                continue

            raw_info = function_raw_info_map.get(node.id, {})
            qualified_name = _qualified_name(node)

            # FUNCTION_CALL_GRAPH is keyed by the qualified name
            # ("Class::method" for members), not by the bare node name.
            calls = FUNCTION_CALL_GRAPH.get(qualified_name, [])
            # CALLED_BY_GRAPH is keyed by the name written at the call site, so
            # callers are collected under both spellings.
            called_by = list(CALLED_BY_GRAPH.get(qualified_name, []))
            if node.name != qualified_name:
                for caller in CALLED_BY_GRAPH.get(node.name, []):
                    if caller not in called_by:
                        called_by.append(caller)

            node.summary = node.summary or f"Function {node.name}, located at {node.file_path}"
            node.logic_flow = node.logic_flow or "Semantic enhancement was skipped."
            node.embedding = node.embedding or [0.0] * embedding_dim
            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                "calls": calls,
                "called_by": called_by,
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })

    return save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=output_dir)
|
||
|
||
|
||
def _confirm_retry():
    """Ask whether to re-enter a path; print the cancel notice and return False on refusal."""
    retry = input("是否重新输入? (y/n): ").strip().lower()
    if retry != 'y':
        print("知识库构建取消。")
        return False
    return True


def build_rag_database_interactive():
    """Interactive main workflow for building the knowledge base.

    Prompts for the mode (whole project / single file / exit) and the target
    path, then runs the four phases: static analysis, call-edge construction,
    optional semantic enhancement and optional persistence.

    :return: ``(all_nodes, all_edges)`` once the analysis has run, including
        when the user declines to save; ``(None, None)`` when the user exits
        or cancels before the analysis starts
    """
    logger.info("=== 交互式知识图谱构建流程 ===")

    # Show the mode selection menu.
    print("\n" + "=" * 50)
    print("知识库构建模式选择")
    print("=" * 50)
    print("1. 全项目模式 - 处理整个项目目录")
    print("2. 单文件模式 - 处理单个文件")
    print("3. 退出")
    print("=" * 50)

    choice = input("请选择模式 (1-3): ").strip()

    if choice == "3":
        logger.info("用户选择退出")
        return None, None

    if choice == "1":
        # Whole-project mode: ask for the project root directly.
        print("\n" + "=" * 60)
        print("全项目知识库构建模式")
        print("=" * 60)

        # Show the default path configured in config.py.
        current_default_path = str(PROJECT_ROOT) if PROJECT_ROOT else "未配置"
        if not current_default_path or current_default_path == ".":
            current_default_path = "当前工作目录"

        while True:
            project_prompt = f"请直接输入要构建知识库的项目根目录完整路径\n(当前config.py中的默认路径: {current_default_path}): "
            user_input = input(project_prompt).strip()

            if user_input:
                # Use the path entered by the user.
                target_path = Path(user_input)
            elif PROJECT_ROOT and os.path.exists(PROJECT_ROOT) and os.path.isdir(PROJECT_ROOT):
                # Nothing entered: fall back to the default from config.py.
                target_path = PROJECT_ROOT
                print(f"使用config.py中的默认路径: {target_path}")
            else:
                print("⚠ 您没有输入路径,且config.py中的默认路径无效。")
                if not _confirm_retry():
                    return None, None
                continue

            # Validate the path.
            if not os.path.exists(target_path):
                print(f"⚠ 路径不存在: {target_path}")
                if not _confirm_retry():
                    return None, None
                continue

            if not os.path.isdir(target_path):
                print(f"⚠ 路径不是目录: {target_path}")
                if not _confirm_retry():
                    return None, None
                continue

            break  # The path is valid.

        mode = "full_project"
        print(f"\n✓ 将处理整个项目: {target_path}")
        print("开始处理...")

    elif choice == "2":
        # Single-file mode.
        while True:
            file_path = input("\n请输入要处理的文件完整路径: ").strip()
            if not file_path:
                print("输入为空,退出单文件模式")
                return None, None

            if not os.path.exists(file_path):
                print(f"错误: 文件不存在: {file_path}")
                continue

            # A directory exists but cannot be opened as a source file later,
            # so reject it here instead of failing during parsing.
            if not os.path.isfile(file_path):
                print(f"错误: 路径不是文件: {file_path}")
                continue

            file_ext = Path(file_path).suffix.lower()
            if file_ext not in CPP_EXTENSIONS:
                print(f"警告: 文件类型 {file_ext} 可能不是C/C++文件,继续吗? (y/n): ", end="")
                confirm = input().strip().lower()
                if confirm != 'y':
                    continue

            target_path = file_path
            mode = "single_file"
            print(f"\n将处理单个文件: {target_path}")
            print("开始处理...")
            break
    else:
        print("无效的选择,退出")
        return None, None

    # Final confirmation before starting.
    print("\n即将开始构建知识库,确认继续? (y/n): ", end="")
    confirm = input().strip().lower()
    if confirm != 'y':
        print("操作已取消")
        return None, None

    # --- Initialise the data structures ---
    all_nodes = []  # GraphNode objects
    all_edges = []  # GraphEdge objects
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}  # temporary raw information per function

    # Reset the module-level graphs.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    # ========================
    # Phase 1: static analysis, extract nodes and structural edges
    # ========================
    logger.info("阶段一:静态分析,提取代码实体与结构关系...")

    if choice == "1":
        processed_files_count = process_project_files(target_path, all_nodes, all_edges,
                                                      symbol_resolver, function_raw_info_map)
        logger.info(f"已处理 {processed_files_count} 个文件")
    else:
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        logger.info(f"已处理单个文件: {target_path}")

    logger.info(f"阶段一完成。初步识别节点数: {len(all_nodes)}")

    # ========================
    # Phase 2: resolve symbols, build the call edges
    # ========================
    build_call_edges(all_nodes, all_edges, symbol_resolver)

    # ========================
    # Phase 3: semantic enhancement, summaries and vectors
    # ========================
    function_node_count = sum(1 for n in all_nodes if n.type == NodeType.FUNCTION)
    print("\n" + "=" * 50)
    print("阶段三:语义增强")
    print("=" * 50)
    print("此阶段将调用LLM API为函数生成摘要和向量嵌入")
    print(f"需要处理的函数节点数: {function_node_count}")
    print("这可能会消耗API调用额度并需要一些时间")
    print("-" * 50)

    proceed = input("是否继续阶段三? (y/n): ").strip().lower()
    if proceed != 'y':
        logger.info("用户跳过阶段三(语义增强)")
        print("跳过阶段三,直接进入阶段四")

        # Give function nodes without a summary their default values.
        for node in all_nodes:
            if node.type != NodeType.FUNCTION or node.summary:
                continue

            raw_info = function_raw_info_map.get(node.id, {})
            qualified_name = _qualified_name(node)

            # FUNCTION_CALL_GRAPH is keyed by the qualified name
            # ("Class::method" for members), not by the bare node name.
            calls = FUNCTION_CALL_GRAPH.get(qualified_name, [])
            # CALLED_BY_GRAPH is keyed by the name written at the call site, so
            # callers are collected under both spellings.
            called_by = list(CALLED_BY_GRAPH.get(qualified_name, []))
            if node.name != qualified_name:
                for caller in CALLED_BY_GRAPH.get(node.name, []):
                    if caller not in called_by:
                        called_by.append(caller)

            node.summary = f"函数 {node.name},位于 {node.file_path}"
            node.logic_flow = "跳过语义增强阶段"
            node.embedding = [0.0] * 1024
            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                "calls": calls,
                "called_by": called_by,
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })
    else:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)

    # ========================
    # Phase 4: persistence
    # ========================
    print("\n" + "=" * 50)
    print("阶段四:持久化存储")
    print("=" * 50)

    save_choice = input("是否保存结果? (y/n): ").strip().lower()
    if save_choice != 'y':
        logger.info("用户取消保存,退出构建流程")
        print("构建流程已取消")
        return all_nodes, all_edges

    # Determine the output file prefix.
    if choice == "1":
        base_name = "project"
    else:
        base_name = f"file_{Path(target_path).stem}"

    graph_data_path, vector_db_path, metadata_path = save_results(
        all_nodes, all_edges, base_name, target_path, mode
    )

    print("\n" + "=" * 50)
    print("构建完成!")
    print("=" * 50)
    print(f"总节点数: {len(all_nodes)}")
    print(f"总边数: {len(all_edges)}")
    print(f"图谱文件: {graph_data_path}")
    # The vector file is only written when function embeddings exist.
    if os.path.exists(vector_db_path):
        print(f"向量文件: {vector_db_path}")

    logger.info("=== 知识图谱构建流程全部完成 ===")
    return all_nodes, all_edges
|