#!/usr/bin/env python3
"""
RAG knowledge-base builder for satellite on-board housekeeping software.

(The Tree-sitter / Qwen invocation style is deliberately kept.)

- Only the relationship (call graph / dependency) handling is optimised.
- The original Tree-sitter parsing and Qwen call flow are preserved.
- A global relationship graph keeps the recorded relations consistent
  between the two build phases.
- Intended for satellite software projects written in C/C++.
"""

import json
import logging
import os
import re
from collections import defaultdict
from pathlib import Path

import faiss
import numpy as np
import requests
import tree_sitter
import tree_sitter_c
import tree_sitter_cpp
from openai import OpenAI
from tqdm import tqdm
from tree_sitter import Language, Parser

# Logging configuration: everything goes both to a file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("rag_builder.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# === Configuration constants ===
QWEN_API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# The key may be pasted here; otherwise it is read from the environment.
# The default stays the empty string, as before.
QWEN_API_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
# Raw string: the path contains a backslash that must not be read as an
# escape sequence.  Replace with the actual project path.
PROJECT_ROOT = r"C:/Users/Administrator/Desktop\静态分析问题单过滤/gs-zk和扫描结果/PrjAttCtrlMng"
VECTOR_DB_PATH = "satellite_rag.faiss"
METADATA_PATH = "satellite_rag_metadata.json"
MAX_CODE_LENGTH = 800  # maximum length of a code snippet sent to the model
EMBEDDING_DIM = 1024  # dimension of the embedding vectors / FAISS index
REQUEST_TIMEOUT_SECONDS = 300  # upper bound for one chat-completion request

# === Global state ===
FUNCTION_CALL_GRAPH = defaultdict(list)  # caller -> [callee]
FILE_DEPENDENCIES = defaultdict(list)    # file path -> [included header]
CALLED_BY_GRAPH = defaultdict(list)      # reverse graph: callee -> [caller]

CPP_LANGUAGE = Language(tree_sitter_cpp.language())
parser = Parser()
parser.language = CPP_LANGUAGE

CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external', 'Debug'}

# Declarator node types that merely wrap the declarator carrying the name.
_WRAPPING_DECLARATORS = (
    "pointer_declarator",
    "reference_declarator",
    "parenthesized_declarator",
)


def _chat_completion(prompt):
    """Send *prompt* as a single user message to qwen-max and return the reply.

    Uses the OpenAI-compatible chat-completions endpoint below QWEN_API_URL.

    Raises:
        requests.RequestException: on HTTP errors, connection problems or
            when the request exceeds REQUEST_TIMEOUT_SECONDS.
    """
    headers = {
        "Authorization": f"Bearer {QWEN_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "qwen-max",
        # Top-level "messages" field; there is no "input" wrapper here.
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1  # low temperature for more consistent output
    }
    response = requests.post(
        f"{QWEN_API_URL}/chat/completions",
        headers=headers,
        json=payload,
        # Without a timeout a stalled connection would block the build forever.
        timeout=REQUEST_TIMEOUT_SECONDS
    )
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content'].strip()


def generate_code_summary(func_name, comment, logic, code_snippet, file_path):
    """Ask the model for a functional summary of one function.

    Args:
        func_name: Name of the function; also the key used to look up its
            relations in the global graphs.
        comment: Comment text found in front of the function (may be empty).
        logic: Flow/logic description, e.g. from generate_code_logic().
        code_snippet: Source text of the function; truncated to
            MAX_CODE_LENGTH characters in the prompt.
        file_path: Path of the defining file; key into FILE_DEPENDENCIES.

    Returns:
        The model reply with every "功能: " marker removed.

    Raises:
        requests.RequestException: if the API request fails.
    """
    called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
    caller_functions = CALLED_BY_GRAPH.get(func_name, [])  # direct dict lookup
    included_headers = FILE_DEPENDENCIES.get(file_path, [])

    prompt = f"""
你是一名资深的航天软件工程师,请总结以下C++函数的核心功能,**必须严格包含以下6点**:
1. 函数流程与逻辑: {logic}
2. 函数目的: → 结合航天术语总结,必须根据上述“函数流程与逻辑”推导生成,不得依赖假设!
3. 输入参数(名称、类型、作用,用括号列出)
4. 返回值(类型和含义)
5. 与其他函数的关联关系:
   - 被调用的函数: {', '.join(called_functions) or '无'}
   - 调用此函数的函数: {', '.join(caller_functions) or '无'}
6. 与跨文件关联的函数(头文件): {', '.join(included_headers) or '无'}
---
函数名: {func_name}
注释: {comment or '无'}
代码片段(截断至{MAX_CODE_LENGTH}字符):
{code_snippet[:MAX_CODE_LENGTH]}

输出格式:
"功能: [按上述6点组织的连贯段落,不要编号]"
"""
    summary = _chat_completion(prompt)
    return summary.replace("功能: ", "")


def generate_code_logic(func_name, comment, code_snippet, file_path):
    """Ask the model for a markdown flow chart and a logic summary.

    The prompt carries the function name, its leading comment and the
    (truncated) source text, so the model has the function it is asked to
    describe.

    Args:
        func_name: Name of the function.
        comment: Comment text found in front of the function (may be empty).
        code_snippet: Source text of the function; truncated to
            MAX_CODE_LENGTH characters in the prompt.
        file_path: Path of the defining file (currently not used in the
            prompt; kept for interface compatibility).

    Returns:
        The stripped model reply.

    Raises:
        requests.RequestException: if the API request fails.
    """
    prompt = f"""
你是一名资深的航天软件工程师,请生成以下C++函数的markdown格式的函数流程图,并总结函数的核心流程逻辑。
---
函数名: {func_name}
注释: {comment or '无'}
代码片段(截断至{MAX_CODE_LENGTH}字符):
{code_snippet[:MAX_CODE_LENGTH]}
"""
    return _chat_completion(prompt)


def safe_decode(b: bytes) -> str:
    """Decode *b*, trying utf-8, gbk, gb2312 and latin-1 in that order."""
    for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
        try:
            return b.decode(encoding)
        except UnicodeDecodeError:
            continue
    # latin-1 accepts every byte sequence, so this is only reached if the
    # list above is edited; kept as a defensive fallback.
    return b.decode('utf-8', errors='replace')


def _node_text(node) -> str:
    """Return the source text of a Tree-sitter node, tolerant of non-UTF-8."""
    return safe_decode(node.text or b"")


# === Code extraction helpers ===
def extract_comment_before(node, source_lines):
    """Collect comment lines located directly in front of *node*.

    Looks at up to nine source lines above the node's first line and keeps
    those that start with '//' or '/*' (after stripping whitespace), in
    source order.

    Args:
        node: Tree-sitter node of the function definition.
        source_lines: The file content split into lines (bytes).

    Returns:
        The kept lines joined by a space, or "无注释" if none was found.
    """
    start_line = node.start_point[0]
    comment_lines = []
    # The stop value is exclusive; -1 lets the scan reach line 0 when the
    # function starts within the first ten lines of the file.
    for i in range(start_line - 1, max(-1, start_line - 10), -1):
        line = safe_decode(source_lines[i]).strip()
        if line.startswith('//') or line.startswith('/*'):
            comment_lines.append(line)
    return " ".join(comment_lines[::-1]) or "无注释"


def extract_code_snippet(tree, node, source_lines):
    """Return the source lines spanned by *node*, joined with newlines.

    Args:
        tree: Parsed tree of the file (not used; kept for compatibility).
        node: Tree-sitter node whose line range is extracted.
        source_lines: The file content split into lines (bytes).
    """
    start_line = node.start_point[0]
    end_line = node.end_point[0]
    last = min(end_line + 1, len(source_lines))
    return "\n".join(safe_decode(source_lines[i]) for i in range(start_line, last))


# === Embedding ===
def get_qwen_embedding(text: str) -> np.ndarray:
    """Embed *text* with the DashScope text-embedding-v4 model.

    The call goes through the OpenAI-compatible SDK.

    Returns:
        A float32 vector.  On any failure the error is logged and a zero
        vector of length EMBEDDING_DIM is returned instead of raising
        (best-effort behaviour).
    """
    try:
        client = OpenAI(
            api_key=QWEN_API_KEY,
            base_url=QWEN_API_URL
        )
        response = client.embeddings.create(
            model="text-embedding-v4",
            input=text  # the SDK accepts a single string or a list of strings
        )
        embedding = response.data[0].embedding
        return np.array(embedding, dtype='float32')
    except Exception as e:
        logger.error(f"获取嵌入失败: {e}")
        return np.zeros(EMBEDDING_DIM, dtype='float32')


def extract_function_name(func_def_node):
    """Extract the function name from a function_definition node.

    field_expression chains are joined with '::'.  Returns "" when the node
    has no declarator.
    """
    declarator = func_def_node.child_by_field_name("declarator")
    if declarator is None:
        return ""

    def get_qualified_name(node):
        if node.type == "identifier":
            return _node_text(node)
        if node.type == "field_expression":
            # left.field -> left may itself be an identifier or a
            # nested field_expression
            left = node.child_by_field_name("argument") or node.children[0]
            right = node.child_by_field_name("field") or node.children[-1]
            left_name = get_qualified_name(left) if left else ""
            right_name = get_qualified_name(right) if right else ""
            if left_name and right_name:
                return f"{left_name}::{right_name}"
            return left_name or right_name
        if node.type == "function_declarator":
            # Function declarator: descend into its own declarator.
            func_declarator = node.child_by_field_name("declarator")
            if func_declarator:
                return get_qualified_name(func_declarator)
        # Fallback: take the raw text in front of the parameter list.
        return _node_text(node).split('(')[0].strip()

    return get_qualified_name(declarator)


def extract_qualified_name_from_field(field_node):
    """Recursively flatten a field_expression into an 'A::B::foo' string."""
    def _extract(n):
        if n.type == "identifier":
            return _node_text(n)
        if n.type == "field_expression":
            # left.field -> argument.field
            left = n.child_by_field_name("argument") or (n.children[0] if n.children else None)
            right = n.child_by_field_name("field") or (n.children[-1] if n.children else None)
            left_str = _extract(left) if left else ""
            right_str = _extract(right) if right else ""
            if left_str and right_str:
                return f"{left_str}::{right_str}"
            return left_str or right_str
        # Fallback: raw text with the parameter list cut off.
        return _node_text(n).split('(')[0].strip()

    return _extract(field_node)


def _unwrap_declarator(declarator):
    """Strip pointer/reference/parenthesis wrappers from a declarator node."""
    while declarator is not None and declarator.type in _WRAPPING_DECLARATORS:
        inner = declarator.child_by_field_name("declarator")
        if inner is None:
            # Some wrappers expose the inner declarator without a field name.
            named = declarator.named_children
            inner = named[-1] if named else None
        if inner is None:
            break
        declarator = inner
    return declarator


def _declarator_name(declarator):
    """Return the function name carried by a (possibly wrapped) declarator."""
    declarator = _unwrap_declarator(declarator)
    if declarator is None:
        return ""

    # Case 1: function_declarator (the common case)
    if declarator.type == "function_declarator":
        inner = _unwrap_declarator(declarator.child_by_field_name("declarator"))
        if inner is None:
            return ""
        if inner.type == "identifier":
            return _node_text(inner)
        if inner.type == "field_expression":
            return extract_qualified_name_from_field(inner)
        if inner.type == "function_declarator":
            # e.g. a function returning a function pointer
            return _declarator_name(inner)
        # Fallback (qualified names, operators, ...): raw text before '('.
        return _node_text(inner).split('(')[0].strip().lstrip('*&').strip()

    # Case 2: the declarator is directly an identifier
    if declarator.type == "identifier":
        return _node_text(declarator)

    # Case 3: anything else; fall back to the raw text before '('.
    return _node_text(declarator).split('(')[0].strip().lstrip('*&').strip()


def extract_full_function_name(func_def_node):
    """Extract a clean function name from a function_definition node.

    Qualified names such as 'A::foo' are kept as written.  Pointer,
    reference and parenthesised declarators are unwrapped first, so a
    function declared as 'int *foo(void)' is named 'foo' rather than '*foo'.

    Returns:
        The name, or "" if it cannot be determined.
    """
    declarator = func_def_node.child_by_field_name("declarator")
    if declarator is None:
        return ""
    return _declarator_name(declarator)


def traverse_type(node, target_type):
    """Yield every node of type *target_type* below *node* in pre-order.

    *node* itself is included.  Implemented with an explicit stack so deep
    syntax trees cannot exhaust the Python recursion limit.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == target_type:
            yield current
        # Reversed so that children are visited left to right.
        stack.extend(reversed(current.children))


def _extract_callee_name(func_node):
    """Extract the called name from the `function` child of a call_expression.

    Only the last name component is returned ('a.b.c' / 'A::f' / 'p->m'
    give 'c' / 'f' / 'm').

    NOTE(review): extract_full_function_name keeps qualifiers ('A::foo'),
    so qualified definitions will not match these unqualified callee names
    in the call graph; confirm whether that is intended.

    Returns:
        The name, or "" if no plausible identifier is found.
    """
    node_type = func_node.type
    if node_type == "identifier":
        return _node_text(func_node)

    text = _node_text(func_node)

    if node_type == "field_expression":
        # a.b or p->b; the grammar labels the member as `field`.
        field = func_node.child_by_field_name("field")
        if field is not None and field.type in ("identifier", "field_identifier"):
            return _node_text(field)
        # Textual fallback for more complex member forms.
        parts = re.split(r'->|[.:]+', text)
        return parts[-1].split('(')[0].strip() if parts else ""

    if node_type == "member_expression":
        member = func_node.child_by_field_name("member")
        if member is not None and member.type == "identifier":
            return _node_text(member)
        for sep in ['->', '.']:
            if sep in text:
                return text.split(sep)[-1].split('(')[0].strip()
        return text.split('(')[0].strip()

    # Fallback: derive the most likely function name from the text.
    text = re.sub(r'<[^>]*>', '', text)  # drop template arguments <...>
    base = text.split('(')[0].strip()    # part in front of the parenthesis
    for sep in ['::', '->', '.']:        # a::b::c -> c
        if sep in base:
            base = base.split(sep)[-1]
    # Simple validity check; identifiers may start with a letter or '_'.
    if base and (base[0].isalpha() or base[0] == '_') \
            and all(c.isalnum() or c in '_$' for c in base):
        return base
    return ""


def _collect_call_expressions(node):
    """Return the sorted, de-duplicated callee names of all calls below *node*."""
    calls = set()
    for call in traverse_type(node, "call_expression"):
        func_node = call.child_by_field_name("function")
        if func_node is None:
            continue
        callee = _extract_callee_name(func_node)
        if callee != "":
            calls.add(callee)
    # Sorted so that the stored metadata is reproducible between runs.
    return sorted(calls)


def _collect_includes(tree):
    """Return the targets of all #include directives in *tree*.

    The whole tree is searched, so includes inside include guards or other
    preprocessor blocks are found as well.
    """
    includes = []
    for node in traverse_type(tree.root_node, "preproc_include"):
        path_node = node.child_by_field_name("path")
        if path_node is not None:
            includes.append(_node_text(path_node).strip('"<>'))
    return includes


def _iter_source_files(project_root):
    """Yield paths of C/C++ files below *project_root*.

    A directory is skipped when any component of its path is listed in
    IGNORE_DIRS.
    """
    for root, _, files in os.walk(project_root):
        if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
            continue
        for file in files:
            if Path(file).suffix.lower() in CPP_EXTENSIONS:
                yield os.path.join(root, file)


def _build_function_record(tree, node, func_name, file_path, source_lines):
    """Summarise one function and return its (embedding vector, metadata)."""
    comment = extract_comment_before(node, source_lines)
    func_code = extract_code_snippet(tree, node, source_lines)
    print(func_code)

    logic = generate_code_logic(func_name, comment, func_code, file_path)
    summary = generate_code_summary(func_name, comment, logic, func_code, file_path)
    print(summary)

    # Relations come from the global graphs built in phase 1; func_name is
    # extracted with the same function in both phases.
    called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
    caller_functions = CALLED_BY_GRAPH.get(func_name, [])
    included_headers = FILE_DEPENDENCIES.get(file_path, [])

    # Context text that gets embedded.
    context_text = (
        f"【实体类型】函数\n"
        f"【函数名】{func_name}\n"
        f"【所在文件】{os.path.basename(file_path)}\n"
        f"【代码注释】{comment.strip() if comment else '无'}\n"
        f"【功能摘要】{summary}\n"
        f"【流程逻辑】{logic}\n"
        f"【调用的函数】{', '.join(called_functions) if called_functions else '无'}\n"
        f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else '无'}\n"
        f"【包含的头文件】{', '.join(included_headers) if included_headers else '无'}"
    )
    vector = get_qwen_embedding(context_text)

    metadata = {
        "file": file_path,
        "line": node.start_point[0] + 1,
        "type": "function",
        "name": func_name,
        "summary": summary,
        "logic": logic,
        "calls": called_functions,
        "called_by": caller_functions,
        "includes": included_headers,
        "comment": comment or ""
    }
    return vector, metadata


# === Main flow ===
def build_rag_database():
    """Build the FAISS index and the metadata file for PROJECT_ROOT.

    Phase 1 parses every source file and fills FUNCTION_CALL_GRAPH,
    FILE_DEPENDENCIES and CALLED_BY_GRAPH.  Phase 2 parses the files again,
    asks the model for a logic description and a summary of every function,
    embeds the resulting context text and stores vector and metadata.

    Function definitions and #include directives are searched in the whole
    syntax tree, not only at the top level.  Errors are logged and skipped:
    per file in phase 1, per function (and per file for read/parse errors)
    in phase 2.

    Returns:
        (index, metadata_list): the FAISS index and the list of metadata
        dicts, which are also written to VECTOR_DB_PATH and METADATA_PATH.
    """
    # Empty the graphs in place so repeated calls do not accumulate entries
    # and every holder of these module-level objects sees the new content.
    FUNCTION_CALL_GRAPH.clear()
    FILE_DEPENDENCIES.clear()
    CALLED_BY_GRAPH.clear()

    index = faiss.IndexFlatL2(EMBEDDING_DIM)
    metadata_list = []

    # ========================
    # Phase 1: call graph and file dependency graph
    # ========================
    logger.info("阶段1: 构建函数调用图和文件依赖图")
    total_funcs = 0

    for file_path in _iter_source_files(PROJECT_ROOT):
        logger.debug(f"解析文件构建关联图: {file_path}")
        try:
            with open(file_path, 'rb') as f:
                source_bytes = f.read()
            tree = parser.parse(source_bytes)

            # --- 1. header dependencies ---
            FILE_DEPENDENCIES[file_path] = _collect_includes(tree)

            # --- 2. function definitions and the calls they make ---
            for node in traverse_type(tree.root_node, "function_definition"):
                func_name = extract_full_function_name(node)
                if func_name == "":
                    continue
                total_funcs += 1
                body = node.child_by_field_name("body")
                FUNCTION_CALL_GRAPH[func_name] = (
                    _collect_call_expressions(body) if body is not None else []
                )
        except Exception as e:
            logger.error(f"解析文件失败 {file_path}: {e}", exc_info=True)

    # --- 3. reverse call graph (who calls me) ---
    for caller, callees in FUNCTION_CALL_GRAPH.items():
        for callee in callees:
            CALLED_BY_GRAPH.setdefault(callee, []).append(caller)

    logger.info(f"阶段1完成。共识别 {total_funcs} 个函数,{len(FILE_DEPENDENCIES)} 个文件依赖。")

    # === Phase 2: summaries and RAG entries ===
    logger.info("阶段2: 生成函数摘要并构建RAG")
    for file_path in _iter_source_files(PROJECT_ROOT):
        logger.debug(f"处理文件: {file_path}")
        try:
            with open(file_path, 'rb') as f:
                source = f.read()
            tree = parser.parse(source)
            source_lines = source.split(b'\n')
            function_nodes = list(traverse_type(tree.root_node, "function_definition"))
        except Exception as e:
            logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True)
            continue

        for node in function_nodes:
            # Same name extraction as in phase 1, so graph lookups match.
            func_name = extract_full_function_name(node)
            if func_name == "":
                logger.warning(f"无法解析函数名,跳过节点: {node.text[:100]}...")
                continue
            try:
                vector, metadata = _build_function_record(
                    tree, node, func_name, file_path, source_lines
                )
            except Exception as e:
                # A failure for one function must not drop the remaining
                # functions of the same file.
                logger.error(
                    f"处理文件 {file_path} 失败: 函数 {func_name}: {str(e)}",
                    exc_info=True
                )
                continue

            index.add(np.array([vector], dtype='float32'))
            metadata_list.append(metadata)

    # Persist the results.
    faiss.write_index(index, VECTOR_DB_PATH)
    with open(METADATA_PATH, 'w', encoding='utf-8') as f:
        json.dump(metadata_list, f, indent=2, ensure_ascii=False)

    logger.info(f"知识库构建完成! 总实体: {len(metadata_list)}")
    logger.info(f"向量数据库保存至: {VECTOR_DB_PATH}")
    logger.info(f"元数据保存至: {METADATA_PATH}")
    return index, metadata_list


if __name__ == "__main__":
    build_rag_database()