"""
graph_builder.py - core pipeline controller for building the code knowledge graph.

Contains the complete build pipeline (static analysis -> call-edge resolution
-> LLM semantic enhancement -> persistence) and the interactive CLI flow.
"""
import json
import logging
import os
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path

try:
    import faiss
except ImportError:  # pragma: no cover - depends on deployment environment
    faiss = None

try:
    import numpy as np
except ImportError:  # pragma: no cover - depends on deployment environment
    np = None

from tqdm import tqdm

from models import GraphNode, GraphEdge, NodeType, EdgeType
from code_parser import (
    parser,
    CPP_EXTENSIONS,
    IGNORE_DIRS,
    safe_decode,
    extract_comment_before,
    extract_code_snippet,
    extract_full_function_name,
    collect_call_expressions,
    extract_class_name,
    extract_member_function_name,
    extract_base_classes,
)
from llm_processor import generate_code_logic, generate_code_summary, get_qwen_embedding
from config import PROJECT_ROOT, MAX_CODE_LENGTH, QWEN_EMBEDDING_MODEL

logger = logging.getLogger(__name__)

# Global graph data structures; every build clears them before use.
# Caller qualified name ("func" or "Class::method") -> raw callee names found in its body.
FUNCTION_CALL_GRAPH = defaultdict(list)
# File path -> include targets of its top-level #include directives.
FILE_DEPENDENCIES = defaultdict(list)
# Callee qualified name -> qualified names of its callers (rebuilt by build_call_edges).
CALLED_BY_GRAPH = defaultdict(list)

# Function node ids have the form "Function:<qualified name>".
_FUNCTION_ID_PREFIX = "Function:"
# Length of placeholder zero vectors when no real embedding is available yet.
DEFAULT_EMBEDDING_DIM = 1024


def _as_node_type(value):
    """Return ``value.value`` for enum-like objects, otherwise *value* unchanged."""
    return value.value if hasattr(value, "value") else value


def _normalize_output_dir(output_dir):
    """Return *output_dir* as a Path, creating it if needed; ``None`` means the current directory."""
    if output_dir is None:
        return Path(".")
    path = Path(output_dir)
    path.mkdir(parents=True, exist_ok=True)
    return path


def _qualified_name(node):
    """Return a node's qualified name.

    For function nodes this is the id without the ``Function:`` prefix
    (``Class::method`` for members); for any other node it is ``node.name``.
    """
    if node.id.startswith(_FUNCTION_ID_PREFIX):
        return node.id[len(_FUNCTION_ID_PREFIX):]
    return node.name


def _has_real_embedding(embedding):
    """Return True if *embedding* is a non-empty numeric sequence that is not all (near) zero."""
    if embedding is None:
        return False
    try:
        # Iterating (instead of `not embedding`) also works for numpy arrays.
        return any(abs(float(value)) > 1e-12 for value in embedding)
    except (TypeError, ValueError):
        return False


def _function_metadata(node, embedding_dim=None):
    """Build the JSON-serialisable metadata record stored next to a function's vector."""
    raw = node.raw_attributes or {}
    embedding_available = raw.get("embedding_available")
    if embedding_available is None:
        embedding_available = _has_real_embedding(node.embedding)
    return {
        "node_id": node.id,
        "name": node.name,
        "function_name": node.name,
        "qualified_name": raw.get("qualified_name") or _qualified_name(node),
        "file": node.file_path,
        "file_path": node.file_path,
        "start_line": node.start_line,
        "end_line": node.end_line,
        "signature": node.signature,
        "summary": node.summary or "",
        "logic_flow": node.logic_flow or "",
        "code_snippet": raw.get("code_snippet", ""),
        "calls": raw.get("calls", []),
        "called_by": raw.get("called_by", []),
        "includes": raw.get("includes", []),
        "embedding_model": QWEN_EMBEDDING_MODEL,
        "embedding_dim": embedding_dim or (len(node.embedding) if node.embedding else 0),
        "embedding_available": bool(embedding_available),
        "embedding_error": raw.get("embedding_error", ""),
    }


class SymbolResolver:
    """Symbol resolver: maps code symbols to graph node ids."""

    def __init__(self):
        # Symbol name -> node id; later registrations of the same name win.
        self._symbol_table = {}
        # File path -> ids of the entities registered from that file.
        self._file_to_entities = defaultdict(list)

    def register_entity(self, node: GraphNode):
        """Register a newly discovered entity in the symbol table.

        The entity is registered under its plain ``name`` and, when it
        differs (member functions), also under its qualified
        ``Class::method`` name so qualified call expressions can resolve.
        """
        self._symbol_table[node.name] = node.id
        qualified = _qualified_name(node)
        if qualified != node.name:
            self._symbol_table[qualified] = node.id
        if node.file_path:
            self._file_to_entities[node.file_path].append(node.id)

    def resolve_name(self, name: str, current_context: dict) -> str:
        """Resolve a name used in code to a node id.

        Tries an exact match first, then the first registered symbol whose
        name ends with ``::<name>``.

        :param name: the name as written in code.
        :param current_context: current context (file, class, ...); currently unused.
        :return: the node id, or None when nothing matches.
        """
        if not name:
            return None
        if name in self._symbol_table:
            return self._symbol_table[name]
        suffix = "::" + name
        for qualified_name, node_id in self._symbol_table.items():
            if qualified_name.endswith(suffix):
                return node_id
        return None


def _signature_of(ast_node):
    """Return the declaration header of *ast_node* (text before the first '{') plus ' {...}'."""
    header = ast_node.text.decode('utf-8', errors='ignore').split('{')[0].strip()
    return header + ' {...}'


def _collect_includes(tree):
    """Return the targets of all top-level ``#include`` directives, without quotes/brackets."""
    includes = []
    for ast_node in tree.root_node.named_children:
        if ast_node.type != "preproc_include":
            continue
        path_node = ast_node.child_by_field_name("path")
        if path_node is not None:
            # Undecodable bytes are dropped, consistent with the other decodes in this module.
            includes.append(path_node.text.decode('utf-8', errors='ignore').strip('"<>'))
    return includes


def _find_base_clause(class_node):
    """Return the base-class clause node of a class/struct node, or None.

    tree-sitter-cpp exposes ``base_class_clause`` as an unnamed child rather
    than a field, so fall back to a search by node type when the field
    lookup finds nothing.
    """
    clause = class_node.child_by_field_name("base_clause")
    if clause is not None:
        return clause
    return next(
        (child for child in class_node.named_children if child.type == "base_class_clause"),
        None,
    )


def _iter_class_members(class_node):
    """Yield candidate member nodes of a class/struct node.

    In tree-sitter-cpp member definitions live inside the ``body``
    (``field_declaration_list``) child, not directly under the specifier;
    direct children are still yielded for parsers that flatten the body.
    """
    yield from class_node.named_children
    body = class_node.child_by_field_name("body")
    if body is not None:
        yield from body.named_children


def _raw_function_info(ast_node, tree, source_lines):
    """Collect the raw data later consumed by semantic enhancement for one function."""
    return {
        'ast_node': ast_node,
        'comment': extract_comment_before(ast_node, source_lines),
        'code_snippet': extract_code_snippet(tree, ast_node, source_lines),
        'source_lines': source_lines,
        'tree': tree,
    }


def process_single_file(file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Parse one C/C++ file and extract its code entities and structural relations.

    Appends FILE / FUNCTION / CLASS nodes to *all_nodes* and CONTAINS /
    EXTENDS / IMPLEMENTS edges to *all_edges*, registers every node with
    *symbol_resolver*, stores per-function raw info in
    *function_raw_info_map*, and updates FILE_DEPENDENCIES and
    FUNCTION_CALL_GRAPH.

    NOTE(review): only top-level declarations are visited (functions inside
    namespaces / extern "C" blocks are not), and base classes are resolved
    against symbols registered so far, so bases defined later are not linked.

    :raises Exception: any read/parse error is logged and re-raised.
    """
    logger.info(f"处理单个文件: {file_path}")
    try:
        with open(file_path, 'rb') as f:
            source_bytes = f.read()
        source_lines = source_bytes.split(b'\n')
        tree = parser.parse(source_bytes)

        # Ids already in the graph, kept in sync with add_node so the member
        # de-duplication check is O(1) instead of a scan of all_nodes.
        known_node_ids = {existing.id for existing in all_nodes}

        def add_node(graph_node):
            """Append *graph_node* to the graph and register it with the resolver."""
            all_nodes.append(graph_node)
            known_node_ids.add(graph_node.id)
            symbol_resolver.register_entity(graph_node)

        # 1. File node.
        file_node_id = f"File:{file_path}"
        add_node(GraphNode(
            id=file_node_id,
            type=NodeType.FILE,
            name=Path(file_path).name,
            file_path=file_path,
        ))

        # 2. Header dependencies.
        FILE_DEPENDENCIES[file_path] = _collect_includes(tree)

        # 3. Top-level declarations.
        for node in tree.root_node.named_children:
            if node.type == "function_definition":
                func_name = extract_full_function_name(node)
                if not func_name:
                    logger.warning(f" 跳过无法解析函数名的节点: {node.text[:50]}...")
                    continue

                func_node_id = f"Function:{func_name}"
                add_node(GraphNode(
                    id=func_node_id,
                    type=NodeType.FUNCTION,
                    name=func_name,
                    signature=_signature_of(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={'ast_node_type': 'function_definition'},
                ))
                function_raw_info_map[func_node_id] = _raw_function_info(node, tree, source_lines)

                # FILE -> CONTAINS -> FUNCTION
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=func_node_id,
                    type=EdgeType.CONTAINS,
                ))
                FUNCTION_CALL_GRAPH[func_name] = collect_call_expressions(node)

            elif node.type in ("class_specifier", "struct_specifier"):
                class_name = extract_class_name(node)
                if not class_name:
                    logger.debug(" 跳过无法解析类名的节点")
                    continue

                class_node_id = f"Class:{class_name}"
                add_node(GraphNode(
                    id=class_node_id,
                    type=NodeType.CLASS,
                    name=class_name,
                    signature=_signature_of(node),
                    file_path=file_path,
                    start_line=node.start_point[0] + 1,
                    end_line=node.end_point[0] + 1,
                    raw_attributes={
                        'ast_type': node.type,
                        'is_struct': node.type == "struct_specifier",
                    },
                ))

                # FILE -> CONTAINS -> CLASS
                all_edges.append(GraphEdge(
                    source_id=file_node_id,
                    target_id=class_node_id,
                    type=EdgeType.CONTAINS,
                ))

                # Inheritance edges.
                base_clause = _find_base_clause(node)
                if base_clause is not None:
                    for base_class_info in extract_base_classes(base_clause) or []:
                        base_name = base_class_info.get('name')
                        if not base_name:
                            continue
                        access_type = base_class_info.get('access', 'public')
                        edge_type = (
                            EdgeType.EXTENDS
                            if access_type in ('public', 'protected', 'private')
                            else EdgeType.IMPLEMENTS
                        )
                        base_id = symbol_resolver.resolve_name(
                            base_name,
                            {'current_file': file_path, 'current_class': class_name},
                        )
                        if base_id:
                            all_edges.append(GraphEdge(
                                source_id=class_node_id,
                                target_id=base_id,
                                type=edge_type,
                                properties={'access': access_type},
                            ))

                # Member functions defined inside the class body.
                for member in _iter_class_members(node):
                    if member.type != "function_definition":
                        continue
                    member_func_name = extract_member_function_name(member, class_name)
                    if not member_func_name:
                        continue
                    qualified_member = f"{class_name}::{member_func_name}"
                    member_node_id = f"Function:{qualified_member}"
                    if member_node_id in known_node_ids:
                        # Already seen (e.g. same class in several files): skip the duplicate.
                        continue

                    add_node(GraphNode(
                        id=member_node_id,
                        type=NodeType.FUNCTION,
                        name=member_func_name,
                        signature=_signature_of(member),
                        file_path=file_path,
                        start_line=member.start_point[0] + 1,
                        end_line=member.end_point[0] + 1,
                        raw_attributes={
                            'is_member': True,
                            'parent_class': class_name,
                            'ast_node_type': 'member_function_definition',
                        },
                    ))
                    function_raw_info_map[member_node_id] = _raw_function_info(member, tree, source_lines)
                    FUNCTION_CALL_GRAPH[qualified_member] = collect_call_expressions(member)

                    # CLASS -> CONTAINS -> METHOD
                    all_edges.append(GraphEdge(
                        source_id=class_node_id,
                        target_id=member_node_id,
                        type=EdgeType.CONTAINS,
                        properties={'member_type': 'method'},
                    ))
    except Exception as e:
        logger.error(f"解析文件 {file_path} 失败: {e}", exc_info=True)
        raise


def process_project_files(project_root, all_nodes, all_edges, symbol_resolver, function_raw_info_map):
    """Walk *project_root* and process every C/C++ source file found.

    Directories in IGNORE_DIRS are pruned. Traversal is sorted so that
    repeated builds register symbols in the same order (the symbol table is
    last-registration-wins).

    :return: number of C/C++ files processed.
    """
    logger.info(f"处理项目目录: {project_root}")
    processed_files = 0
    for root, dirs, files in os.walk(project_root):
        dirs[:] = sorted(d for d in dirs if d not in IGNORE_DIRS)
        for file_name in sorted(files):
            if Path(file_name).suffix.lower() not in CPP_EXTENSIONS:
                continue
            full_file_path = os.path.join(root, file_name)
            process_single_file(full_file_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
            processed_files += 1
    return processed_files


def _relation_attributes(func_node):
    """Return call-graph / include attributes for a function node, keyed by its qualified name."""
    qualified = _qualified_name(func_node)
    return {
        "calls": list(FUNCTION_CALL_GRAPH.get(qualified, [])),
        "called_by": list(CALLED_BY_GRAPH.get(qualified, [])),
        "includes": list(FILE_DEPENDENCIES.get(func_node.file_path, [])),
        "qualified_name": qualified,
    }


def _enhance_function_node(func_node, raw_info, placeholder_dim):
    """Generate logic flow, summary and embedding for one function node, in place.

    LLM failures are logged and replaced by fixed fallback texts or a zero
    vector of *placeholder_dim*, so one failing function never aborts the batch.

    :return: length of the real embedding obtained, or None if embedding failed.
    """
    func_name = func_node.name
    file_path = func_node.file_path
    comment = raw_info['comment']
    code_snippet = raw_info['code_snippet']

    func_node.raw_attributes.update({
        "comment": comment,
        "code_snippet": code_snippet,
        **_relation_attributes(func_node),
    })

    # 1. Logic flow.
    try:
        logic = generate_code_logic(func_name, comment, code_snippet, file_path)
    except Exception as e:
        logger.error(f"为函数 {func_name} 生成逻辑失败: {e}")
        logic = "逻辑生成失败。"

    # 2. Functional summary (built on top of the logic flow).
    try:
        summary = generate_code_summary(func_name, comment, logic, code_snippet, file_path)
    except Exception as e:
        logger.error(f"为函数 {func_name} 生成摘要失败: {e}")
        summary = "摘要生成失败。"

    func_node.summary = summary
    func_node.logic_flow = logic

    # 3. Embedding of the name + summary + logic + file context.
    context_text_for_embedding = (
        f"函数名: {func_name}\n"
        f"功能摘要: {summary}\n"
        f"流程逻辑: {logic}\n"
        f"所在文件: {os.path.basename(file_path)}"
    )
    try:
        embedding_vector = get_qwen_embedding(context_text_for_embedding)
        embedding_values = (
            embedding_vector.tolist()
            if hasattr(embedding_vector, "tolist")
            else list(embedding_vector)
        )
        if not _has_real_embedding(embedding_values):
            raise RuntimeError("Embedding API returned an empty or zero vector.")
        func_node.embedding = embedding_values
    except Exception as e:
        logger.error(f"为函数 {func_name} 生成嵌入失败: {e}")
        func_node.embedding = [0.0] * placeholder_dim
        func_node.raw_attributes["embedding_available"] = False
        func_node.raw_attributes["embedding_error"] = str(e)[:500]
        return None
    func_node.raw_attributes["embedding_available"] = True
    func_node.raw_attributes["embedding_error"] = ""
    return len(embedding_values)


def semantic_enhancement(all_nodes, all_edges, function_raw_info_map, embedding_dim=DEFAULT_EMBEDDING_DIM):
    """Phase 3: semantic enhancement - generate summaries and embeddings for function nodes.

    :param all_nodes: graph nodes; FUNCTION nodes are updated in place.
    :param all_edges: unused; kept for interface compatibility.
    :param function_raw_info_map: node id -> raw info produced by process_single_file.
    :param embedding_dim: placeholder vector length used until a real embedding
        is obtained; afterwards placeholders match the real embedding length so
        the vector index stays rectangular.
    """
    logger.info("阶段三:语义增强,为函数节点生成摘要与向量嵌入...")

    function_nodes = [n for n in all_nodes if n.type == NodeType.FUNCTION]
    nodes_with_raw_info = [n for n in function_nodes if n.id in function_raw_info_map]
    nodes_without_raw_info = [n for n in function_nodes if n.id not in function_raw_info_map]

    logger.info(f"有原始信息的节点: {len(nodes_with_raw_info)} 个")
    if nodes_without_raw_info:
        logger.warning(f"无原始信息的节点: {len(nodes_without_raw_info)} 个")

    placeholder_dim = embedding_dim

    if nodes_with_raw_info:
        total = len(nodes_with_raw_info)
        print(f"\n开始为 {total} 个函数生成摘要和向量...")
        print("这可能需要一些时间,请耐心等待...")
        progress = tqdm(nodes_with_raw_info, desc="生成函数摘要", unit="个")
        for processed_count, func_node in enumerate(progress, start=1):
            real_dim = _enhance_function_node(
                func_node, function_raw_info_map[func_node.id], placeholder_dim
            )
            if real_dim:
                placeholder_dim = real_dim

            if processed_count % 5 == 0:
                print(f" 已处理 {processed_count}/{total} 个函数")
            # Small delay to avoid hitting the API rate limit.
            time.sleep(0.1)
        print(f"✓ 完成 {total} 个函数的语义增强")
    else:
        logger.warning("没有找到任何有原始信息的函数节点,跳过摘要生成。")

    # Nodes without source information only get placeholder data.
    if nodes_without_raw_info:
        print(f"\n为 {len(nodes_without_raw_info)} 个无原始信息的节点生成基础信息...")
        for func_node in nodes_without_raw_info:
            func_node.summary = f"函数 {func_node.name},位于 {func_node.file_path}"
            func_node.logic_flow = "无法生成逻辑流程(缺少原始代码信息)"
            func_node.embedding = [0.0] * placeholder_dim
            func_node.raw_attributes.update({
                **_relation_attributes(func_node),
                "embedding_available": False,
                "embedding_error": "Missing source information.",
            })


def _apply_skipped_semantics(all_nodes, function_raw_info_map, summary_template, logic_flow, embedding_dim):
    """Fill placeholder semantics for every function node when phase 3 is skipped.

    Existing summary / logic flow / embedding values are kept; missing ones
    are filled. *summary_template* is formatted with ``name`` and ``file_path``.
    """
    for node in all_nodes:
        if node.type != NodeType.FUNCTION:
            continue
        raw_info = function_raw_info_map.get(node.id, {})
        node.summary = node.summary or summary_template.format(name=node.name, file_path=node.file_path)
        node.logic_flow = node.logic_flow or logic_flow
        node.embedding = node.embedding or [0.0] * embedding_dim
        node.raw_attributes.update({
            "comment": raw_info.get("comment", ""),
            "code_snippet": raw_info.get("code_snippet", ""),
            **_relation_attributes(node),
            "embedding_available": False,
            "embedding_error": "Semantic enhancement was skipped.",
        })


def build_call_edges(all_nodes, all_edges, symbol_resolver):
    """Phase 2: resolve callee names into CALLS edges and rebuild CALLED_BY_GRAPH.

    CALLED_BY_GRAPH is keyed by the resolved callee's qualified name (node id
    without the ``Function:`` prefix), matching FUNCTION_CALL_GRAPH's keys;
    callees resolving to non-function nodes are keyed by their raw call name.
    Each (caller, callee) pair produces at most one edge.
    """
    logger.info("阶段二:解析符号,构建函数调用图 (CALLS 边)...")

    CALLED_BY_GRAPH.clear()
    # node id -> file path of the first node with that id (replaces per-call scans).
    file_by_node_id = {}
    for node in all_nodes:
        file_by_node_id.setdefault(node.id, node.file_path)

    seen_pairs = set()
    for caller_name, callee_names in FUNCTION_CALL_GRAPH.items():
        caller_node_id = f"{_FUNCTION_ID_PREFIX}{caller_name}"
        if caller_node_id not in file_by_node_id:
            continue
        context = {'current_file': file_by_node_id[caller_node_id]}

        for callee_name in callee_names:
            callee_node_id = symbol_resolver.resolve_name(callee_name, context)
            if not callee_node_id or (caller_node_id, callee_node_id) in seen_pairs:
                continue
            seen_pairs.add((caller_node_id, callee_node_id))

            all_edges.append(GraphEdge(
                source_id=caller_node_id,
                target_id=callee_node_id,
                type=EdgeType.CALLS,
            ))
            if callee_node_id.startswith(_FUNCTION_ID_PREFIX):
                callee_key = callee_node_id[len(_FUNCTION_ID_PREFIX):]
            else:
                callee_key = callee_name
            CALLED_BY_GRAPH[callee_key].append(caller_name)

    logger.info(f"阶段二完成。已构建边数: {len(all_edges)}")


def _collect_index_vectors(function_nodes):
    """Return ``(dimension, vectors, metadata)`` for the vector index.

    The dimension comes from the first real (non-zero) embedding, or from the
    first node if none is real. Placeholder vectors of another length are
    resized to zeros of that dimension; real embeddings of another length are
    skipped with a warning, because numpy/FAISS need a rectangular matrix.
    """
    dimension = next(
        (len(node.embedding) for node in function_nodes if _has_real_embedding(node.embedding)),
        len(function_nodes[0].embedding),
    )
    vectors = []
    metadata = []
    for node in function_nodes:
        vector = list(node.embedding)
        if len(vector) != dimension:
            if _has_real_embedding(vector):
                logger.warning(f" 跳过维度不一致的向量: {node.id} ({len(vector)} != {dimension})")
                continue
            vector = [0.0] * dimension
        vectors.append(vector)
        metadata.append(_function_metadata(node, dimension))
    return dimension, vectors, metadata


def save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=None):
    """Phase 4: persist the graph JSON, the vector index and its metadata.

    :return: ``(graph_data_path, vector_db_path, metadata_path)``; the vector
        and metadata files are only written when function vectors exist.
    """
    logger.info("阶段四:持久化图谱与向量索引...")

    output_path = _normalize_output_dir(output_dir)
    graph_data_path = str(output_path / f"{base_name}_code_knowledge_graph.json")
    vector_db_path = str(output_path / f"{base_name}_rag.faiss")
    metadata_path = str(output_path / f"{base_name}_rag_metadata.json")

    # 1. Graph file.
    graph_data = {
        "metadata": {
            "project_root": str(target_path) if mode == "full_project" else "single_file",
            "file_path": str(target_path) if mode == "single_file" else None,
            "mode": mode,
            "generated_at": datetime.now().isoformat(),
            "total_nodes": len(all_nodes),
            "total_edges": len(all_edges),
        },
        "nodes": [node.dict() for node in all_nodes],
        "edges": [edge.dict() for edge in all_edges],
    }
    with open(graph_data_path, 'w', encoding='utf-8') as f:
        json.dump(graph_data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f" 图谱结构已保存至: {graph_data_path}")
    print(f"✓ 图谱结构已保存: {graph_data_path}")

    # 2. Vector index.
    function_nodes = [node for node in all_nodes if node.type == NodeType.FUNCTION and node.embedding]
    vectors = []
    if function_nodes:
        dimension, vectors, metadata_for_faiss = _collect_index_vectors(function_nodes)

    if vectors:
        if faiss is not None:
            if np is None:
                raise RuntimeError("numpy is required when writing a FAISS index.")
            vectors_np = np.array(vectors, dtype='float32')
            index = faiss.IndexFlatL2(dimension)
            index.add(vectors_np)
            faiss.write_index(index, vector_db_path)
        else:
            # FAISS unavailable: write a plain JSON vector store at the same path.
            with open(vector_db_path, 'w', encoding='utf-8') as f:
                json.dump(
                    {
                        "format": "simple_l2_vector_index",
                        "dimension": dimension,
                        "vectors": [[float(value) for value in vector] for vector in vectors],
                    },
                    f,
                    ensure_ascii=False,
                )
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata_for_faiss, f, indent=2, ensure_ascii=False)

        logger.info(f" 向量索引已保存。函数节点数: {len(vectors)}")
        logger.info(f" 向量数据库: {vector_db_path}")
        logger.info(f" 元数据: {metadata_path}")
        print(f"✓ 向量索引已保存: {vector_db_path}")
    else:
        logger.warning(" 没有可生成向量的函数节点,跳过FAISS索引构建。")
        print("⚠ 无向量数据,跳过FAISS索引构建")

    return graph_data_path, vector_db_path, metadata_path


def build_code_knowledge_base(target_path, output_dir=None, semantic=True, base_name=None,
                              embedding_dim=DEFAULT_EMBEDDING_DIM):
    """Build a code knowledge base without interactive prompts.

    :param target_path: project directory or single source file.
    :param output_dir: output directory (created if missing); defaults to cwd.
    :param semantic: run LLM semantic enhancement; otherwise placeholders are used.
    :param base_name: output file prefix; derived from the mode when omitted.
    :param embedding_dim: placeholder vector length when no real embedding exists.
    :return: ``(graph_data_path, vector_db_path, metadata_path)``.
    :raises FileNotFoundError: if *target_path* does not exist.
    """
    target_path = str(Path(target_path).resolve())
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path does not exist: {target_path}")

    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    if os.path.isdir(target_path):
        mode = "full_project"
        process_project_files(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or "project"
    else:
        mode = "single_file"
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        base_name = base_name or f"file_{Path(target_path).stem}"

    build_call_edges(all_nodes, all_edges, symbol_resolver)

    if semantic:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map, embedding_dim=embedding_dim)
    else:
        _apply_skipped_semantics(
            all_nodes,
            function_raw_info_map,
            "Function {name}, located at {file_path}",
            "Semantic enhancement was skipped.",
            embedding_dim,
        )

    return save_results(all_nodes, all_edges, base_name, target_path, mode, output_dir=output_dir)


def _prompt_retry():
    """Ask whether to re-enter a path; print the cancel message and return False on anything but 'y'."""
    retry = input("是否重新输入? (y/n): ").strip().lower()
    if retry != 'y':
        print("知识库构建取消。")
        return False
    return True


def _prompt_project_root():
    """Interactively ask for a project root directory; return it, or None if the user cancels.

    Empty input falls back to PROJECT_ROOT from config.py when that is an
    existing directory.
    """
    print("\n" + "=" * 60)
    print("全项目知识库构建模式")
    print("=" * 60)

    current_default_path = str(PROJECT_ROOT) if PROJECT_ROOT else "未配置"
    if not current_default_path or current_default_path == ".":
        current_default_path = "当前工作目录"
    project_prompt = (
        f"请直接输入要构建知识库的项目根目录完整路径\n"
        f"(当前config.py中的默认路径: {current_default_path}): "
    )

    while True:
        user_input = input(project_prompt).strip()
        if user_input:
            target_path = Path(user_input)
        elif PROJECT_ROOT and os.path.isdir(PROJECT_ROOT):
            target_path = PROJECT_ROOT
            print(f"使用config.py中的默认路径: {target_path}")
        else:
            print("⚠ 您没有输入路径,且config.py中的默认路径无效。")
            if not _prompt_retry():
                return None
            continue

        if not os.path.exists(target_path):
            print(f"⚠ 路径不存在: {target_path}")
            if not _prompt_retry():
                return None
            continue
        if not os.path.isdir(target_path):
            print(f"⚠ 路径不是目录: {target_path}")
            if not _prompt_retry():
                return None
            continue
        return target_path


def _prompt_single_file():
    """Interactively ask for a single source file; return its path, or None on empty input."""
    while True:
        file_path = input("\n请输入要处理的文件完整路径: ").strip()
        if not file_path:
            print("输入为空,退出单文件模式")
            return None
        if not os.path.exists(file_path):
            print(f"错误: 文件不存在: {file_path}")
            continue
        if os.path.isdir(file_path):
            # A directory would make process_single_file fail in open().
            print(f"错误: 路径是目录而不是文件: {file_path}")
            continue
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in CPP_EXTENSIONS:
            print(f"警告: 文件类型 {file_ext} 可能不是C/C++文件,继续吗? (y/n): ", end="")
            if input().strip().lower() != 'y':
                continue
        return file_path


def build_rag_database_interactive():
    """Interactive knowledge-base build flow.

    Prompts for the mode and target, runs phases 1-4 with confirmations,
    and returns ``(all_nodes, all_edges)``, or ``(None, None)`` when the
    user exits or cancels before the build starts.
    """
    logger.info("=== 交互式知识图谱构建流程 ===")

    print("\n" + "=" * 50)
    print("知识库构建模式选择")
    print("=" * 50)
    print("1. 全项目模式 - 处理整个项目目录")
    print("2. 单文件模式 - 处理单个文件")
    print("3. 退出")
    print("=" * 50)

    choice = input("请选择模式 (1-3): ").strip()

    if choice == "3":
        logger.info("用户选择退出")
        return None, None

    if choice == "1":
        target_path = _prompt_project_root()
        if target_path is None:
            return None, None
        mode = "full_project"
        print(f"\n✓ 将处理整个项目: {target_path}")
        print("开始处理...")
    elif choice == "2":
        target_path = _prompt_single_file()
        if target_path is None:
            return None, None
        mode = "single_file"
        print(f"\n将处理单个文件: {target_path}")
        print("开始处理...")
    else:
        print("无效的选择,退出")
        return None, None

    print("\n即将开始构建知识库,确认继续? (y/n): ", end="")
    if input().strip().lower() != 'y':
        print("操作已取消")
        return None, None

    # --- Fresh build state ---
    all_nodes = []
    all_edges = []
    symbol_resolver = SymbolResolver()
    function_raw_info_map = {}  # temporary raw info per function node id
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    # Phase 1: static analysis - extract nodes and structural edges.
    logger.info("阶段一:静态分析,提取代码实体与结构关系...")
    if mode == "full_project":
        processed_files_count = process_project_files(
            target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map
        )
        logger.info(f"已处理 {processed_files_count} 个文件")
    else:
        process_single_file(target_path, all_nodes, all_edges, symbol_resolver, function_raw_info_map)
        logger.info(f"已处理单个文件: {target_path}")
    logger.info(f"阶段一完成。初步识别节点数: {len(all_nodes)}")

    # Phase 2: symbol resolution - CALLS edges.
    build_call_edges(all_nodes, all_edges, symbol_resolver)

    # Phase 3: semantic enhancement (optional, consumes LLM API quota).
    function_count = sum(1 for n in all_nodes if n.type == NodeType.FUNCTION)
    print("\n" + "=" * 50)
    print("阶段三:语义增强")
    print("=" * 50)
    print("此阶段将调用LLM API为函数生成摘要和向量嵌入")
    print(f"需要处理的函数节点数: {function_count}")
    print("这可能会消耗API调用额度并需要一些时间")
    print("-" * 50)

    proceed = input("是否继续阶段三? (y/n): ").strip().lower()
    if proceed != 'y':
        logger.info("用户跳过阶段三(语义增强)")
        print("跳过阶段三,直接进入阶段四")
        _apply_skipped_semantics(
            all_nodes,
            function_raw_info_map,
            "函数 {name},位于 {file_path}",
            "跳过语义增强阶段",
            DEFAULT_EMBEDDING_DIM,
        )
    else:
        semantic_enhancement(all_nodes, all_edges, function_raw_info_map)

    # Phase 4: persistence.
    print("\n" + "=" * 50)
    print("阶段四:持久化存储")
    print("=" * 50)

    save_choice = input("是否保存结果? (y/n): ").strip().lower()
    if save_choice != 'y':
        logger.info("用户取消保存,退出构建流程")
        print("构建流程已取消")
        return all_nodes, all_edges

    base_name = "project" if mode == "full_project" else f"file_{Path(target_path).stem}"
    graph_data_path, vector_db_path, _metadata_path = save_results(
        all_nodes, all_edges, base_name, target_path, mode
    )

    print("\n" + "=" * 50)
    print("构建完成!")
    print("=" * 50)
    print(f"总节点数: {len(all_nodes)}")
    print(f"总边数: {len(all_edges)}")
    print(f"图谱文件: {graph_data_path}")
    if os.path.exists(vector_db_path):
        print(f"向量文件: {vector_db_path}")

    logger.info("=== 知识图谱构建流程全部完成 ===")
    return all_nodes, all_edges