826 lines
33 KiB
Python
826 lines
33 KiB
Python
|
|
"""
|
|||
|
|
graph_builder.py - 知识图谱构建的核心流程控制器
|
|||
|
|
包含完整的构建流程和交互逻辑
|
|||
|
|
"""
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
try:
|
|||
|
|
import faiss
|
|||
|
|
except ImportError: # pragma: no cover - depends on deployment environment
|
|||
|
|
faiss = None
|
|||
|
|
try:
|
|||
|
|
import numpy as np
|
|||
|
|
except ImportError: # pragma: no cover - depends on deployment environment
|
|||
|
|
np = None
|
|||
|
|
import time
|
|||
|
|
from datetime import datetime
|
|||
|
|
from pathlib import Path
|
|||
|
|
import logging
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
from collections import defaultdict
|
|||
|
|
|
|||
|
|
from models import GraphNode, GraphEdge, NodeType, EdgeType
|
|||
|
|
from code_parser import (
|
|||
|
|
parser, CPP_EXTENSIONS, IGNORE_DIRS,
|
|||
|
|
safe_decode, extract_comment_before, extract_code_snippet,
|
|||
|
|
extract_full_function_name, collect_call_expressions,
|
|||
|
|
extract_class_name, extract_member_function_name, extract_base_classes
|
|||
|
|
)
|
|||
|
|
from llm_processor import generate_code_logic, generate_code_summary, get_qwen_embedding
|
|||
|
|
from config import PROJECT_ROOT, MAX_CODE_LENGTH, QWEN_EMBEDDING_MODEL
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Global graph data structures shared by the build stages in this module.
# Both build entry points reset all three with .clear() before each build.

# FUNCTION_CALL_GRAPH: caller name -> callees returned by collect_call_expressions.
# Top-level functions are keyed by their parsed name, member functions by
# "ClassName::method" (filled in process_single_file).
FUNCTION_CALL_GRAPH = defaultdict(list)

# FILE_DEPENDENCIES: file path -> list of #include targets with quotes/angle brackets stripped.
FILE_DEPENDENCIES = defaultdict(list)

# CALLED_BY_GRAPH: callee name as written at the call site -> list of caller names
# (the keys of FUNCTION_CALL_GRAPH); rebuilt by build_call_edges.
CALLED_BY_GRAPH = defaultdict(list)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _as_node_type(value):
    """Unwrap an enum-like value.

    Returns ``value.value`` when *value* has a ``value`` attribute (e.g. an
    ``Enum`` member); otherwise returns *value* unchanged.
    """
    if not hasattr(value, "value"):
        return value
    return value.value
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _normalize_output_dir(output_dir):
    """Return the directory that build outputs should be written to.

    ``None`` means the current directory and is returned as ``Path(".")``
    without touching the filesystem. Any other value is converted to a
    ``Path`` and created, including missing parents, if it does not exist yet.
    """
    if output_dir is not None:
        destination = Path(output_dir)
        destination.mkdir(parents=True, exist_ok=True)
        return destination
    return Path(".")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _qualified_name(node):
    """Return the qualified name encoded in a function node's ID.

    For IDs of the form ``"Function:<qualified>"`` the part after the prefix
    is returned (e.g. ``"Cls::method"`` for member functions). For any other
    ID the node's ``name`` is returned.
    """
    prefix = "Function:"
    node_id = node.id
    if not node_id.startswith(prefix):
        return node.name
    return node_id[len(prefix):]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _has_real_embedding(embedding):
    """Return True if *embedding* contains at least one non-zero component.

    The following all count as "no embedding":

    - ``None``
    - empty sequences
    - all-zero vectors (every |component| <= 1e-12)
    - inputs that are not iterable, or whose items cannot be converted to ``float``

    Lists, tuples and 1-D numpy arrays are all accepted.
    """
    # Explicit None check rather than `if not embedding`: truth-testing a numpy
    # array with more than one element raises ValueError ("ambiguous").
    if embedding is None:
        return False
    try:
        return any(abs(float(value)) > 1e-12 for value in embedding)
    except (TypeError, ValueError):
        # Non-iterable input or non-numeric items.
        return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _function_metadata(node, embedding_dim=None):
    """Build the JSON-serialisable metadata record stored next to a function's vector.

    Args:
        node: a function GraphNode. Reads id, name, file_path, start_line,
            end_line, signature, summary, logic_flow, embedding and
            raw_attributes.
        embedding_dim: dimension written into the record. When falsy, the
            length of ``node.embedding`` is used (0 if the node has none).

    Returns:
        dict with the node's identity and location, its LLM summary and logic
        flow, its code/call/include context, and embedding status fields.
    """
    attrs = node.raw_attributes or {}

    available = attrs.get("embedding_available")
    if available is None:
        # The flag was never recorded for this node; infer it from the vector itself.
        available = _has_real_embedding(node.embedding)

    if not embedding_dim:
        embedding_dim = len(node.embedding) if node.embedding else 0

    path = node.file_path
    return {
        "node_id": node.id,
        "name": node.name,
        "function_name": node.name,
        "qualified_name": attrs.get("qualified_name") or _qualified_name(node),
        "file": path,
        "file_path": path,
        "start_line": node.start_line,
        "end_line": node.end_line,
        "signature": node.signature,
        "summary": node.summary or "",
        "logic_flow": node.logic_flow or "",
        "code_snippet": attrs.get("code_snippet", ""),
        "calls": attrs.get("calls", []),
        "called_by": attrs.get("called_by", []),
        "includes": attrs.get("includes", []),
        "embedding_model": QWEN_EMBEDDING_MODEL,
        "embedding_dim": embedding_dim,
        "embedding_available": bool(available),
        "embedding_error": attrs.get("embedding_error", ""),
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SymbolResolver:
    """Symbol resolver: maps names found in code to graph node IDs.

    Every registered entity is indexed under its ``name``. Function and class
    nodes whose ID carries a longer qualified name (e.g. ``Function:A::bar``
    for a member named ``bar``) are also indexed under that qualified name.
    A bare name shared by a class and a non-class entity keeps pointing at
    the class.
    """

    # Node-ID prefixes whose remainder is a qualified C++ name worth indexing.
    _QUALIFIED_PREFIXES = ("Function:", "Class:")

    def __init__(self):
        # name or qualified name -> node ID
        self._symbol_table = {}
        # last "::" segment -> symbol-table keys with that tail, in insertion
        # order; lets resolve_name's suffix match avoid scanning every key.
        self._keys_by_tail = defaultdict(list)
        # file path -> IDs of the entities registered from that file
        self._file_to_entities = defaultdict(list)

    @staticmethod
    def _tail(text):
        """Return the part of *text* after its last '::' (all of it if there is none)."""
        return text.rsplit("::", 1)[-1]

    def _store(self, key, node_id, is_class):
        """Map *key* to *node_id*, never letting a non-class replace a class."""
        existing = self._symbol_table.get(key)
        if existing is not None and existing.startswith("Class:") and not is_class:
            # A member function sharing its class's bare name is typically the
            # constructor; keep the class so base-class lookups stay correct.
            return
        if key not in self._symbol_table and isinstance(key, str):
            self._keys_by_tail[self._tail(key)].append(key)
        self._symbol_table[key] = node_id

    def register_entity(self, node: "GraphNode"):
        """Register a newly discovered entity in the symbol table.

        Args:
            node: object with ``id``, ``name`` and ``file_path`` attributes.
        """
        keys = [node.name]
        for prefix in self._QUALIFIED_PREFIXES:
            if node.id.startswith(prefix):
                qualified = node.id[len(prefix):]
                if qualified and qualified != node.name:
                    keys.append(qualified)
                break

        is_class = node.id.startswith("Class:")
        for key in keys:
            self._store(key, node.id, is_class)

        if node.file_path:
            self._file_to_entities[node.file_path].append(node.id)

    def resolve_name(self, name: str, current_context: dict) -> "str | None":
        """Resolve a name used in code to a node ID.

        An exact match wins. Otherwise the first registered key (in
        registration order) that ends with ``"::" + name`` is used.

        Args:
            name: the name as written in code, possibly qualified.
            current_context: caller context (current file/class). It is
                currently not used for resolution.

        Returns:
            The node ID, or None if nothing matches.
        """
        if name in self._symbol_table:
            return self._symbol_table[name]
        if not isinstance(name, str):
            return None

        suffix = "::" + name
        # Any key ending with `suffix` has the same last "::" segment as `suffix`,
        # so only keys filed under that tail need checking.
        for key in self._keys_by_tail.get(self._tail(suffix), ()):
            if key.endswith(suffix):
                return self._symbol_table[key]
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _node_signature(ast_node):
    """Return the text of *ast_node* up to its first '{', followed by ' {...}'."""
    header = ast_node.text.decode('utf-8', errors='ignore').split('{')[0]
    return header.strip() + ' {...}'


def _collect_includes(root_node):
    """Return the targets of the top-level #include directives of a parsed file.

    Surrounding quotes / angle brackets are stripped from each path.
    """
    includes = []
    for child in root_node.named_children:
        if child.type != "preproc_include":
            continue
        path_node = child.child_by_field_name("path")
        if path_node is not None:
            # errors='ignore' so a non-UTF-8 include path cannot abort the whole
            # file (the other decodes in this module already do the same).
            includes.append(path_node.text.decode('utf-8', errors='ignore').strip('"<>'))
    return includes


def _find_base_clause(class_ast):
    """Return the base-class clause of a class/struct AST node, or None."""
    base_clause = class_ast.child_by_field_name("base_clause")
    if base_clause is not None:
        return base_clause
    # tree-sitter-cpp exposes the clause as a child of type "base_class_clause"
    # that is not bound to a field, so the field lookup above normally misses it.
    for child in class_ast.named_children:
        if child.type == "base_class_clause":
            return child
    return None


def _member_function_nodes(class_ast):
    """Return the function_definition nodes defined inside a class/struct."""
    candidates = list(class_ast.named_children)
    # Methods live in the class body (field_declaration_list), not among the
    # class_specifier's own direct children.
    body = class_ast.child_by_field_name("body")
    if body is not None:
        candidates.extend(body.named_children)
    return [member for member in candidates if member.type == "function_definition"]


def _function_raw_info(ast_node, tree, source_lines):
    """Collect the raw parse data that semantic_enhancement reads for one function."""
    return {
        'ast_node': ast_node,
        'comment': extract_comment_before(ast_node, source_lines),
        'code_snippet': extract_code_snippet(tree, ast_node, source_lines),
        'source_lines': source_lines,
        'tree': tree,
    }


def process_single_file(file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Parse one C/C++ file and add its code entities and structural edges to the graph.

    Nodes added:
        - a FILE node;
        - FUNCTION nodes for top-level function definitions and for function
          definitions inside top-level classes/structs;
        - CLASS nodes for top-level classes/structs.

    Edges added:
        - CONTAINS edges (file -> function, file -> class, class -> method);
        - EXTENDS/IMPLEMENTS edges to base classes the resolver already knows.

    Side effects: #include targets go into FILE_DEPENDENCIES and each
    function's callees into FUNCTION_CALL_GRAPH.

    Args:
        file_path: path of the source file to parse.
        all_nodes: list the new GraphNode objects are appended to.
        all_edges: list the new GraphEdge objects are appended to.
        symbol_resolver: SymbolResolver every new entity is registered with.
        function_raw_info_map: node ID -> raw parse info (AST node, comment,
            code snippet, source lines, tree), filled for every new function.

    Raises:
        Any exception from reading or parsing is logged and re-raised.
    """
    logger.info(f"处理单个文件: {file_path}")

    try:
        with open(file_path, 'rb') as f:
            source_bytes = f.read()
        source_lines = source_bytes.split(b'\n')
        tree = parser.parse(source_bytes)

        # IDs already in the graph (possibly from earlier files). A set makes the
        # member-function duplicate check O(1) instead of a scan of all_nodes.
        known_ids = {n.id for n in all_nodes}

        # 1. File node
        file_node_id = f"File:{file_path}"
        file_node = GraphNode(
            id=file_node_id,
            type=NodeType.FILE,
            name=Path(file_path).name,
            file_path=file_path,
        )
        all_nodes.append(file_node)
        known_ids.add(file_node_id)
        symbol_resolver.register_entity(file_node)

        # 2. Header dependencies
        FILE_DEPENDENCIES[file_path] = _collect_includes(tree.root_node)

        # 3. Top-level declarations
        for node in tree.root_node.named_children:
            if node.type == "function_definition":
                # 3.1 Free (or out-of-class) function definition
                func_name = extract_full_function_name(node)
                if func_name == "<unknown>":
                    logger.warning(f" 跳过无法解析函数名的节点: {node.text[:50]}...")
                    continue

                func_node_id = f"Function:{func_name}"
                func_node = GraphNode(
                    id=func_node_id,
                    type=NodeType.FUNCTION,
                    name=func_name,
                    signature=_node_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={'ast_node_type': 'function_definition'},
                )
                all_nodes.append(func_node)
                known_ids.add(func_node_id)
                symbol_resolver.register_entity(func_node)
                function_raw_info_map[func_node_id] = _function_raw_info(node, tree, source_lines)

                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=func_node_id,
                    type=EdgeType.CONTAINS,
                ))
                FUNCTION_CALL_GRAPH[func_name] = collect_call_expressions(node)

            elif node.type in ("class_specifier", "struct_specifier"):
                # 3.2 Class / struct definition
                class_name = extract_class_name(node)
                if not class_name or class_name == "<unknown>":
                    logger.debug(" 跳过无法解析类名的节点")
                    continue

                class_node_id = f"Class:{class_name}"
                class_node = GraphNode(
                    id=class_node_id,
                    type=NodeType.CLASS,
                    name=class_name,
                    signature=_node_signature(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={
                        'ast_type': node.type,
                        'is_struct': node.type == "struct_specifier",
                    },
                )
                all_nodes.append(class_node)
                known_ids.add(class_node_id)
                symbol_resolver.register_entity(class_node)

                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=class_node_id,
                    type=EdgeType.CONTAINS,
                ))

                # Inheritance; extract_base_classes is assumed to accept the
                # base-class clause node (TODO confirm against code_parser).
                base_clause = _find_base_clause(node)
                if base_clause is not None:
                    for base_class_info in extract_base_classes(base_clause):
                        base_name = base_class_info.get('name')
                        if not base_name:
                            continue
                        access_type = base_class_info.get('access', 'public')
                        edge_type = (
                            EdgeType.EXTENDS
                            if access_type in ('public', 'protected', 'private')
                            else EdgeType.IMPLEMENTS
                        )
                        base_id = symbol_resolver.resolve_name(
                            base_name,
                            {'current_file': file_path, 'current_class': class_name},
                        )
                        if base_id:
                            all_edges.append(GraphEdge(
                                source_id=class_node_id,
                                target_id=base_id,
                                type=edge_type,
                                properties={'access': access_type},
                            ))

                # Methods defined inside the class body
                for member in _member_function_nodes(node):
                    member_func_name = extract_member_function_name(member, class_name)
                    if not member_func_name or member_func_name == "<unknown>":
                        continue

                    qualified_member = f"{class_name}::{member_func_name}"
                    member_node_id = f"Function:{qualified_member}"
                    if member_node_id in known_ids:
                        # Same ID already present (e.g. an overload); keep the first one.
                        continue

                    member_node = GraphNode(
                        id=member_node_id,
                        type=NodeType.FUNCTION,
                        name=member_func_name,
                        signature=_node_signature(member),
                        file_path=file_path,
                        start_line=member.start_point[0] + 1,
                        end_line=member.end_point[0] + 1,
                        raw_attributes={
                            'is_member': True,
                            'parent_class': class_name,
                            'ast_node_type': 'member_function_definition',
                        },
                    )
                    all_nodes.append(member_node)
                    known_ids.add(member_node_id)
                    symbol_resolver.register_entity(member_node)
                    function_raw_info_map[member_node_id] = _function_raw_info(member, tree, source_lines)

                    FUNCTION_CALL_GRAPH[qualified_member] = collect_call_expressions(member)

                    all_edges.append(GraphEdge(
                        source_id=class_node_id,
                        target_id=member_node_id,
                        type=EdgeType.CONTAINS,
                        properties={'member_type': 'method'},
                    ))

    except Exception as e:
        logger.error(f"解析文件 {file_path} 失败: {e}", exc_info=True)
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def process_project_files(project_root, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Walk a project directory and run process_single_file on every C/C++ source.

    Directories named in IGNORE_DIRS are pruned. Files are selected by a
    case-insensitive suffix match against CPP_EXTENSIONS. Directories and
    files are visited in sorted order so repeated builds produce the same
    node/edge order.

    Args:
        project_root: root directory to walk.
        all_nodes, all_edges, symbol_resolver, function_raw_info_map: passed
            through to process_single_file.

    Returns:
        int: number of files handed to process_single_file (0 if none).

    Raises:
        Whatever process_single_file raises; it logs and re-raises, so one
        unparsable file stops the walk.
    """
    logger.info(f"处理项目目录: {project_root}")

    processed_count = 0
    for root, dirs, files in os.walk(project_root):
        # Assigning to dirs[:] in place prunes what os.walk descends into.
        dirs[:] = sorted(d for d in dirs if d not in IGNORE_DIRS)

        for file in sorted(files):
            if Path(file).suffix.lower() not in CPP_EXTENSIONS:
                continue
            full_file_path = os.path.join(root, file)
            process_single_file(full_file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
            processed_count += 1

    return processed_count
|
|||
|
|
|
|||
|
|
|
|||
|
|
def semantic_enhancement(all_nodes, all_edges, function_raw_info_map, embedding_dim=1024):
    """Stage 3: semantic enhancement — generate summaries and embeddings for function nodes.

    For every function node with raw source info from stage 1, the LLM is
    asked for a logic flow and then a summary, and a text built from both is
    embedded. Function nodes without raw info get placeholder text.

    Failed or missing embeddings become zero vectors. Their length matches the
    first successful embedding (or ``embedding_dim`` if none succeeded), so
    every vector handed to the index has the same dimension.

    Args:
        all_nodes: list of GraphNode; FUNCTION nodes are updated in place.
        all_edges: unused; kept for interface compatibility.
        function_raw_info_map: node ID -> dict with at least 'comment' and
            'code_snippet' (as built by process_single_file).
        embedding_dim: fallback vector length when no embedding call
            succeeded. Defaults to 1024, the previously hard-coded value.
    """
    logger.info("阶段三:语义增强,为函数节点生成摘要与向量嵌入...")

    # 1. Split function nodes by whether stage 1 recorded their source.
    function_nodes = [n for n in all_nodes if n.type == NodeType.FUNCTION]
    nodes_with_raw_info = [n for n in function_nodes if n.id in function_raw_info_map]
    nodes_without_raw_info = [n for n in function_nodes if n.id not in function_raw_info_map]

    logger.info(f"有原始信息的节点: {len(nodes_with_raw_info)} 个")
    if nodes_without_raw_info:
        logger.warning(f"无原始信息的节点: {len(nodes_without_raw_info)} 个")

    observed_dim = None   # length of the first successful embedding
    failed_nodes = []     # nodes whose zero-vector fallback is assigned after the loop

    # 2. Nodes with raw source info: LLM logic + summary + embedding.
    if nodes_with_raw_info:
        total = len(nodes_with_raw_info)
        print(f"\n开始为 {total} 个函数生成摘要和向量...")
        print("这可能需要一些时间,请耐心等待...")

        progress = tqdm(nodes_with_raw_info, desc="生成函数摘要", unit="个")
        for processed_count, func_node in enumerate(progress, start=1):
            func_name = func_node.name
            raw_info = function_raw_info_map[func_node.id]
            comment = raw_info['comment']
            code_snippet = raw_info['code_snippet']
            file_path = func_node.file_path
            qualified_name = _qualified_name(func_node)
            includes = FILE_DEPENDENCIES.get(file_path, [])

            func_node.raw_attributes.update({
                "comment": comment,
                "code_snippet": code_snippet,
                "includes": includes,
                "qualified_name": qualified_name,
            })

            try:
                logic = generate_code_logic(func_name, comment, code_snippet, file_path)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成逻辑失败: {e}")
                logic = "逻辑生成失败。"

            try:
                summary = generate_code_summary(func_name, comment, logic, code_snippet, file_path)
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成摘要失败: {e}")
                summary = "摘要生成失败。"

            func_node.summary = summary
            func_node.logic_flow = logic

            # FUNCTION_CALL_GRAPH keys member functions as "Class::method", so
            # look calls up by the qualified name rather than the bare name.
            func_node.raw_attributes.update({
                'calls': FUNCTION_CALL_GRAPH.get(qualified_name, []),
                'called_by': CALLED_BY_GRAPH.get(func_name, []),
                'includes': includes,
            })

            context_text_for_embedding = (
                f"函数名: {func_name}\n"
                f"功能摘要: {summary}\n"
                f"流程逻辑: {logic}\n"
                f"所在文件: {os.path.basename(file_path)}"
            )

            try:
                embedding_vector = get_qwen_embedding(context_text_for_embedding)
                embedding_values = (
                    embedding_vector.tolist()
                    if hasattr(embedding_vector, "tolist")
                    else list(embedding_vector)
                )
                if not _has_real_embedding(embedding_values):
                    raise RuntimeError("Embedding API returned an empty or zero vector.")
                if observed_dim is None:
                    observed_dim = len(embedding_values)
                elif len(embedding_values) != observed_dim:
                    raise RuntimeError(
                        f"Embedding dimension {len(embedding_values)} differs from {observed_dim}."
                    )
                func_node.embedding = embedding_values
                func_node.raw_attributes["embedding_available"] = True
                func_node.raw_attributes["embedding_error"] = ""
            except Exception as e:
                logger.error(f"为函数 {func_name} 生成嵌入失败: {e}")
                failed_nodes.append(func_node)
                func_node.raw_attributes["embedding_available"] = False
                func_node.raw_attributes["embedding_error"] = str(e)[:500]

            if processed_count % 5 == 0:
                print(f" 已处理 {processed_count}/{total} 个函数")

            # Throttle to avoid API rate limiting.
            time.sleep(0.1)

        print(f"✓ 完成 {total} 个函数的语义增强")
    else:
        logger.warning("没有找到任何有原始信息的函数节点,跳过摘要生成。")

    # Zero vectors sized like the real ones keep the index matrix rectangular.
    fallback_dim = observed_dim or embedding_dim
    for func_node in failed_nodes:
        func_node.embedding = [0.0] * fallback_dim

    # 3. Nodes without raw source info: placeholder text only.
    if nodes_without_raw_info:
        print(f"\n为 {len(nodes_without_raw_info)} 个无原始信息的节点生成基础信息...")
        for func_node in nodes_without_raw_info:
            qualified_name = _qualified_name(func_node)
            func_node.summary = f"函数 {func_node.name},位于 {func_node.file_path}"
            func_node.logic_flow = "无法生成逻辑流程(缺少原始代码信息)"
            func_node.embedding = [0.0] * fallback_dim
            func_node.raw_attributes.update({
                "calls": FUNCTION_CALL_GRAPH.get(qualified_name, []),
                "called_by": CALLED_BY_GRAPH.get(func_node.name, []),
                "includes": FILE_DEPENDENCIES.get(func_node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Missing source information.",
            })
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_call_edges(all_nodes, all_edges, symbol_resolver):
    """Stage 2: build CALLS edges from FUNCTION_CALL_GRAPH and rebuild CALLED_BY_GRAPH.

    Each caller in FUNCTION_CALL_GRAPH must exist as node ``Function:<caller>``.
    Each callee name is resolved through *symbol_resolver*; unresolved names
    are skipped.

    Duplicates are suppressed:
        - each (caller node, callee node) pair yields at most one CALLS edge;
        - each caller appears at most once per callee name in CALLED_BY_GRAPH.

    Args:
        all_nodes: all GraphNode objects from stage 1.
        all_edges: list that CALLS edges are appended to.
        symbol_resolver: SymbolResolver populated in stage 1.
    """
    logger.info("阶段二:解析符号,构建函数调用图 (CALLS 边)...")

    CALLED_BY_GRAPH.clear()

    # node ID -> file path of the first node with that ID; replaces an O(n)
    # scan of all_nodes for every caller and every callee.
    file_by_node_id = {}
    for node in all_nodes:
        file_by_node_id.setdefault(node.id, node.file_path)

    seen_edges = set()
    for caller_name, callee_names in FUNCTION_CALL_GRAPH.items():
        caller_node_id = f"Function:{caller_name}"
        if caller_node_id not in file_by_node_id:
            continue
        caller_file = file_by_node_id[caller_node_id]

        for callee_name in callee_names:
            callee_node_id = symbol_resolver.resolve_name(callee_name, {'current_file': caller_file})
            if not callee_node_id:
                continue

            edge_key = (caller_node_id, callee_node_id)
            if edge_key not in seen_edges:
                seen_edges.add(edge_key)
                all_edges.append(GraphEdge(
                    source_id=caller_node_id,
                    target_id=callee_node_id,
                    type=EdgeType.CALLS,
                ))

            callers = CALLED_BY_GRAPH[callee_name]
            if caller_name not in callers:
                callers.append(caller_name)

    logger.info(f"阶段二完成。已构建边数: {len(all_edges)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=None):
    """Stage 4: persist the graph JSON, the vector index and its metadata.

    Files written into the directory chosen by _normalize_output_dir(output_dir):
        - ``<base_name>_code_knowledge_graph.json``: all nodes and edges;
        - ``<base_name>_rag.faiss``: a FAISS IndexFlatL2, or a JSON vector dump
          when faiss is not installed;
        - ``<base_name>_rag_metadata.json``: one record per indexed vector, in
          index order.

    Only function nodes with an embedding are indexed. Embeddings whose
    length differs from the first one are skipped with a warning, because a
    ragged matrix cannot be turned into an index.

    Returns:
        (graph_data_path, vector_db_path, metadata_path) as strings. The vector
        and metadata files are only written when at least one function node has
        an embedding.

    Raises:
        RuntimeError: if faiss is installed but numpy is not.
    """
    logger.info("阶段四:持久化图谱与向量索引...")

    output_path = _normalize_output_dir(output_dir)
    graph_data_path = str(output_path / f"{base_name}_code_knowledge_graph.json")
    vector_db_path = str(output_path / f"{base_name}_rag.faiss")
    metadata_path = str(output_path / f"{base_name}_rag_metadata.json")

    # 1. Graph file
    graph_data = {
        "metadata": {
            "project_root": target_path if mode == "full_project" else "single_file",
            "file_path": target_path if mode == "single_file" else None,
            "mode": mode,
            "generated_at": datetime.now().isoformat(),
            "total_nodes": len(all_nodes),
            "total_edges": len(all_edges),
        },
        "nodes": [node.dict() for node in all_nodes],
        "edges": [edge.dict() for edge in all_edges],
    }

    with open(graph_data_path, 'w', encoding='utf-8') as f:
        json.dump(graph_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f" 图谱结构已保存至: {graph_data_path}")
    print(f"✓ 图谱结构已保存: {graph_data_path}")

    # 2. Vector index + metadata
    function_nodes = [node for node in all_nodes if node.type == NodeType.FUNCTION and node.embedding]
    if function_nodes:
        dimension = len(function_nodes[0].embedding)
        vectors = []
        metadata_for_faiss = []
        skipped = 0
        for node in function_nodes:
            if len(node.embedding) != dimension:
                # Skip the vector and its metadata together so they stay aligned.
                skipped += 1
                continue
            vectors.append(node.embedding)
            metadata_for_faiss.append(_function_metadata(node, dimension))

        if skipped:
            logger.warning(f" 跳过 {skipped} 个向量维度与 {dimension} 不一致的函数节点。")

        if faiss is not None:
            if np is None:
                raise RuntimeError("numpy is required when writing a FAISS index.")
            index = faiss.IndexFlatL2(dimension)
            index.add(np.array(vectors, dtype='float32'))
            faiss.write_index(index, vector_db_path)
        else:
            # Fallback without faiss: plain JSON with the raw vectors.
            with open(vector_db_path, 'w', encoding='utf-8') as f:
                json.dump(
                    {
                        "format": "simple_l2_vector_index",
                        "dimension": dimension,
                        "vectors": [[float(value) for value in vector] for vector in vectors],
                    },
                    f,
                    ensure_ascii=False,
                )

        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata_for_faiss, f, indent=2, ensure_ascii=False)

        logger.info(f" 向量索引已保存。函数节点数: {len(vectors)}")
        logger.info(f" 向量数据库: {vector_db_path}")
        logger.info(f" 元数据: {metadata_path}")
        print(f"✓ 向量索引已保存: {vector_db_path}")
    else:
        logger.warning(" 没有可生成向量的函数节点,跳过FAISS索引构建。")
        print("⚠ 无向量数据,跳过FAISS索引构建")

    return graph_data_path, vector_db_path, metadata_path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_code_knowledge_base(target_path, output_dir=None, semantic=True, base_name=None, embedding_dim=1024):
    """Build a code knowledge base without interactive prompts.

    Runs static analysis, call-edge construction, optional semantic
    enhancement, and persistence.

    Args:
        target_path: project directory or single source file. ``~`` is
            expanded and the path is made absolute.
        output_dir: directory for the output files, created if missing.
            None writes to the current directory.
        semantic: when False, LLM enhancement is skipped and function nodes
            get placeholder summaries and zero vectors.
        base_name: prefix for output file names. Defaults to "project" for a
            directory, or "file_<stem>" for a single file.
        embedding_dim: length of the placeholder zero vectors used when
            ``semantic`` is False.

    Returns:
        (graph_data_path, vector_db_path, metadata_path) from save_results.

    Raises:
        FileNotFoundError: if *target_path* does not exist.
    """
    target_path = str(Path(target_path).expanduser().resolve())
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path does not exist: {target_path}")

    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}

    # The graphs are module-level; reset them so a previous build in the same
    # process does not leak into this one.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    if os.path.isdir(target_path):
        mode = "full_project"
        process_project_files(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or "project"
    else:
        mode = "single_file"
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or f"file_{Path(target_path).stem}"

    build_call_edges(all_nodes, all_edges, symbol_resolver)

    if semantic:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)
    else:
        for node in all_nodes:
            if node.type != NodeType.FUNCTION:
                continue
            raw_info = function_raw_info_map.get(node.id, {})
            qualified_name = _qualified_name(node)
            node.summary = node.summary or f"Function {node.name}, located at {node.file_path}"
            node.logic_flow = node.logic_flow or "Semantic enhancement was skipped."
            node.embedding = node.embedding or [0.0] * embedding_dim
            node.raw_attributes.update({
                "comment": raw_info.get("comment", ""),
                "code_snippet": raw_info.get("code_snippet", ""),
                # FUNCTION_CALL_GRAPH keys member functions as "Class::method".
                "calls": FUNCTION_CALL_GRAPH.get(qualified_name, []),
                "called_by": CALLED_BY_GRAPH.get(node.name, []),
                "includes": FILE_DEPENDENCIES.get(node.file_path, []),
                "qualified_name": qualified_name,
                "embedding_available": False,
                "embedding_error": "Semantic enhancement was skipped.",
            })

    return save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=output_dir)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _clean_user_path(raw_path):
    """Normalise a path typed or pasted at a prompt.

    Strips surrounding whitespace and one pair of matching quotes (some file
    managers wrap copied paths in quotes), then expands a leading ``~``.
    Returns "" for empty input.
    """
    text = raw_path.strip()
    if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
        text = text[1:-1].strip()
    return os.path.expanduser(text) if text else text


def _prompt_project_root():
    """Ask for a project root directory until a valid one is given.

    Empty input falls back to PROJECT_ROOT from config when that is an
    existing directory.

    Returns:
        The chosen directory as a string, or None if the user cancels.
    """
    print("\n" + "=" * 60)
    print("全项目知识库构建模式")
    print("=" * 60)

    # Show the default path configured in config.py.
    current_default_path = str(PROJECT_ROOT) if PROJECT_ROOT else "未配置"
    if not current_default_path or current_default_path == ".":
        current_default_path = "当前工作目录"

    while True:
        project_prompt = f"请直接输入要构建知识库的项目根目录完整路径\n(当前config.py中的默认路径: {current_default_path}): "
        user_input = _clean_user_path(input(project_prompt))

        if user_input:
            target_path = user_input
        elif PROJECT_ROOT and os.path.exists(PROJECT_ROOT) and os.path.isdir(PROJECT_ROOT):
            target_path = PROJECT_ROOT
            print(f"使用config.py中的默认路径: {target_path}")
        else:
            print("⚠ 您没有输入路径,且config.py中的默认路径无效。")
            if input("是否重新输入? (y/n): ").strip().lower() != 'y':
                print("知识库构建取消。")
                return None
            continue

        if not os.path.exists(target_path):
            print(f"⚠ 路径不存在: {target_path}")
            if input("是否重新输入? (y/n): ").strip().lower() != 'y':
                print("知识库构建取消。")
                return None
            continue

        if not os.path.isdir(target_path):
            print(f"⚠ 路径不是目录: {target_path}")
            if input("是否重新输入? (y/n): ").strip().lower() != 'y':
                print("知识库构建取消。")
                return None
            continue

        return str(target_path)


def _prompt_single_file():
    """Ask for a single source file until an acceptable one is given.

    A suffix outside CPP_EXTENSIONS needs explicit confirmation.

    Returns:
        The chosen file path, or None if the user enters nothing.
    """
    while True:
        file_path = _clean_user_path(input("\n请输入要处理的文件完整路径: "))
        if not file_path:
            print("输入为空,退出单文件模式")
            return None

        if not os.path.exists(file_path):
            print(f"错误: 文件不存在: {file_path}")
            continue

        # A directory here would only fail later inside open().
        if not os.path.isfile(file_path):
            print(f"错误: 不是文件: {file_path}")
            continue

        file_ext = Path(file_path).suffix.lower()
        if file_ext not in CPP_EXTENSIONS:
            print(f"警告: 文件类型 {file_ext} 可能不是C/C++文件,继续吗? (y/n): ", end="")
            if input().strip().lower() != 'y':
                continue

        return file_path


def _fill_skipped_semantic_fields(all_nodes, function_raw_info_map):
    """Give function nodes without a summary placeholder semantic fields.

    Used when stage 3 is skipped. Each such node gets a summary, a logic
    flow, a 1024-dim zero embedding, and its comment/code/call/include context.
    """
    for node in all_nodes:
        if node.type != NodeType.FUNCTION or node.summary:
            continue
        qualified_name = _qualified_name(node)
        node.summary = f"函数 {node.name},位于 {node.file_path}"
        node.logic_flow = "跳过语义增强阶段"
        node.embedding = [0.0] * 1024
        raw_info = function_raw_info_map.get(node.id, {})
        node.raw_attributes.update({
            "comment": raw_info.get("comment", ""),
            "code_snippet": raw_info.get("code_snippet", ""),
            # FUNCTION_CALL_GRAPH keys member functions as "Class::method".
            "calls": FUNCTION_CALL_GRAPH.get(qualified_name, []),
            "called_by": CALLED_BY_GRAPH.get(node.name, []),
            "includes": FILE_DEPENDENCIES.get(node.file_path, []),
            "qualified_name": qualified_name,
            "embedding_available": False,
            "embedding_error": "Semantic enhancement was skipped.",
        })


def build_rag_database_interactive():
    """Interactive knowledge-base build: prompts for a target, then runs stages 1-4.

    Returns:
        (all_nodes, all_edges) once the graph has been built, including when
        the user declines to save. (None, None) if the user exits or cancels
        before building.
    """
    logger.info("=== 交互式知识图谱构建流程 ===")

    # Mode selection menu
    print("\n" + "=" * 50)
    print("知识库构建模式选择")
    print("=" * 50)
    print("1. 全项目模式 - 处理整个项目目录")
    print("2. 单文件模式 - 处理单个文件")
    print("3. 退出")
    print("=" * 50)

    choice = input("请选择模式 (1-3): ").strip()

    if choice == "3":
        logger.info("用户选择退出")
        return None, None

    if choice == "1":
        target_path = _prompt_project_root()
        if target_path is None:
            return None, None
        mode = "full_project"
        print(f"\n✓ 将处理整个项目: {target_path}")
        print("开始处理...")
    elif choice == "2":
        target_path = _prompt_single_file()
        if target_path is None:
            return None, None
        mode = "single_file"
        print(f"\n将处理单个文件: {target_path}")
        print("开始处理...")
    else:
        print("无效的选择,退出")
        return None, None

    # Final confirmation
    print("\n即将开始构建知识库,确认继续? (y/n): ", end="")
    if input().strip().lower() != 'y':
        print("操作已取消")
        return None, None

    # --- Build state ---
    all_nodes = []              # GraphNode objects
    all_edges = []              # GraphEdge objects
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}  # temporary raw info per function node

    # Reset the module-level graphs left over from any previous build.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    # Stage 1: static analysis — extract nodes and structural edges.
    logger.info("阶段一:静态分析,提取代码实体与结构关系...")
    if mode == "full_project":
        processed_files_count = process_project_files(
            target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map
        )
        logger.info(f"已处理 {processed_files_count} 个文件")
    else:
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        logger.info(f"已处理单个文件: {target_path}")
    logger.info(f"阶段一完成。初步识别节点数: {len(all_nodes)}")

    # Stage 2: resolve symbols and build CALLS edges.
    build_call_edges(all_nodes, all_edges, symbol_resolver)

    # Stage 3: semantic enhancement (optional; costs LLM API calls).
    print("\n" + "=" * 50)
    print("阶段三:语义增强")
    print("=" * 50)
    print("此阶段将调用LLM API为函数生成摘要和向量嵌入")
    print(f"需要处理的函数节点数: {len([n for n in all_nodes if n.type == NodeType.FUNCTION])}")
    print("这可能会消耗API调用额度并需要一些时间")
    print("-" * 50)

    if input("是否继续阶段三? (y/n): ").strip().lower() != 'y':
        logger.info("用户跳过阶段三(语义增强)")
        print("跳过阶段三,直接进入阶段四")
        _fill_skipped_semantic_fields(all_nodes, function_raw_info_map)
    else:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)

    # Stage 4: persistence.
    print("\n" + "=" * 50)
    print("阶段四:持久化存储")
    print("=" * 50)

    if input("是否保存结果? (y/n): ").strip().lower() != 'y':
        logger.info("用户取消保存,退出构建流程")
        print("构建流程已取消")
        return all_nodes, all_edges

    base_name = "project" if mode == "full_project" else f"file_{Path(target_path).stem}"

    graph_data_path, vector_db_path, metadata_path = save_results(
        all_nodes, all_edges, base_name, target_path, mode
    )

    print("\n" + "=" * 50)
    print("构建完成!")
    print("=" * 50)
    print(f"总节点数: {len(all_nodes)}")
    print(f"总边数: {len(all_edges)}")
    print(f"图谱文件: {graph_data_path}")
    if os.path.exists(vector_db_path):
        print(f"向量文件: {vector_db_path}")

    logger.info("=== 知识图谱构建流程全部完成 ===")
    return all_nodes, all_edges
|