diff --git a/Knowledge_Build.py b/Knowledge_Build.py new file mode 100644 index 0000000..1080b83 --- /dev/null +++ b/Knowledge_Build.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python3 +""" +卫星星务软件代码RAG知识库构建工具(严格保留Tree-sitter/Qwen调用方式) +- 仅优化关联关系处理逻辑 +- 保留原始Tree-sitter解析和Qwen-V3调用流程 +- 通过全局关联图确保关联关系正确性 +- 适用于卫星软件项目(C/C++) +""" + +import os +import re +import json +from collections import defaultdict +from pathlib import Path +import numpy as np +import faiss +import tree_sitter +import tree_sitter_c +import tree_sitter_cpp +from openai import OpenAI +from tree_sitter import Language, Parser +import requests +from tqdm import tqdm +import logging + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("rag_builder.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# === 配置常量 === +QWEN_API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +QWEN_API_KEY = "" +PROJECT_ROOT = "C:/Users/Administrator/Desktop\静态分析问题单过滤/gs-zk和扫描结果/PrjAttCtrlMng" # 替换为实际项目路径 +VECTOR_DB_PATH = "satellite_rag.faiss" +METADATA_PATH = "satellite_rag_metadata.json" +MAX_CODE_LENGTH = 800 # 代码片段最大长度 + + +# === 全局状态 === +FUNCTION_CALL_GRAPH = defaultdict(list) +FILE_DEPENDENCIES = defaultdict(list) +CALLED_BY_GRAPH = defaultdict(list) # 反向图:callee -> [caller] +CPP_LANGUAGE = Language(tree_sitter_cpp.language()) +parser = Parser() +parser.language=CPP_LANGUAGE +CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'} +IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external','Debug'} + +def generate_code_summary(func_name, comment, logic, code_snippet, file_path): + called_functions = FUNCTION_CALL_GRAPH.get(func_name, []) + caller_functions = CALLED_BY_GRAPH.get(func_name, []) # ← 直接获取,O(1) + included_headers = FILE_DEPENDENCIES.get(file_path, []) + + prompt = f""" +你是一名资深的航天软件工程师,请总结以下C++函数的核心功能,**必须严格包含以下6点**: + +1. 函数流程与逻辑: +{logic} + +2. 函数目的: +→ 结合航天术语总结,必须根据上述“函数流程与逻辑”推导生成,不得依赖假设! + +3. 输入参数(名称、类型、作用,用括号列出) + +4. 返回值(类型和含义) + +5. 与其他函数的关联关系: + - 被调用的函数: {', '.join(called_functions) or '无'} + - 调用此函数的函数: {', '.join(caller_functions) or '无'} + +6. 与跨文件关联的函数(头文件): {', '.join(included_headers) or '无'} + +--- +函数名: {func_name} +注释: {comment or '无'} +代码片段(截断至{MAX_CODE_LENGTH}字符): +{code_snippet[:MAX_CODE_LENGTH]} + +输出格式: "功能: [按上述6点组织的连贯段落,不要编号]" +""" + + headers = {"Authorization": f"Bearer {QWEN_API_KEY}", "Content-Type": "application/json"} + payload = { + "model": "qwen-max", + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.1 + } + + response = requests.post( + "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", + headers=headers, + json=payload + ) + response.raise_for_status() + summary = response.json()['choices'][0]['message']['content'].strip() + return summary.replace("功能: ", "") +def generate_code_logic(func_name, comment, code_snippet, file_path): + prompt = f""" + 你是一名资深的航天软件工程师,请生成以下C++函数的markdown格式的函数流程图,并总结函数的核心流程逻辑。 + """ + # 你的提示词保持不变 + + headers = { + "Authorization": f"Bearer {QWEN_API_KEY}", + "Content-Type": "application/json" + } + payload = { + "model": "qwen-max", + "messages": [ # ← 直接顶层字段,无 "input" 包裹 + {"role": "user", "content": prompt} + ], + "temperature": 0.1 # 可选:提高摘要一致性 + } + + # 使用 OpenAI 兼容端点 + response = requests.post( + "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", + headers=headers, + json=payload + ) + response.raise_for_status() + result = response.json() + + logic = result['choices'][0]['message']['content'].strip() # ← 注意路径变化! + return logic +def safe_decode(b: bytes) -> str: + for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']: + try: + return b.decode(encoding) + except UnicodeDecodeError: + continue + return b.decode('utf-8', errors='replace') + +# === 严格保留的代码提取辅助函数 === +def extract_comment_before(node, source_lines): + """严格保留注释提取逻辑(已验证正确)""" + start_line = node.start_point[0] + comment_lines = [] + + for i in range(start_line - 1, max(0, start_line - 10), -1): + line = safe_decode(source_lines[i]).strip() + if line.startswith('//') or line.startswith('/*'): + comment_lines.append(line) + + return " ".join(comment_lines[::-1]) or "无注释" + + +def extract_code_snippet(tree, node, source_lines): + """严格保留代码片段提取逻辑(已验证正确)""" + start_line = node.start_point[0] + end_line = node.end_point[0] + + code_lines = [] + for i in range(start_line, end_line + 1): + if i < len(source_lines): + code_lines.append(safe_decode(source_lines[i])) + + return "\n".join(code_lines) + + +# === 严格保留的向量化嵌入 === +def get_qwen_embedding(text: str) -> np.ndarray: + """ + 调用 DashScope 的 text-embedding-v4 模型生成 1024 维文本嵌入向量。 + 使用 OpenAI 兼容 SDK 方式调用。 + """ + try: + client = OpenAI( + api_key=QWEN_API_KEY, + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" + ) + response = client.embeddings.create( + model="text-embedding-v4", + input=text # OpenAI SDK 支持单个字符串或字符串列表 + ) + embedding = response.data[0].embedding + return np.array(embedding, dtype='float32') + except Exception as e: + # 假设 logger 已定义;若未定义,可临时用 print 替代 + try: + logger.error(f"获取嵌入失败: {e}") + except NameError: + print(f"获取嵌入失败: {e}") + return np.zeros(1024, dtype='float32') +def extract_function_name(func_def_node): + """ + 从 function_definition 节点中提取完整的函数名(含命名空间/类作用域) + """ + declarator = func_def_node.child_by_field_name("declarator") + if not declarator: + return "" + + # 处理 A::B::func 这种形式:可能是 field_expression 嵌套 + def get_qualified_name(node): + if node.type == "identifier": + return node.text.decode('utf-8') + elif node.type == "field_expression": + # left.field -> left 可能是 identifier 或 field_expression + left = node.child_by_field_name("argument") or node.children[0] + right = node.child_by_field_name("field") or node.children[-1] + left_name = get_qualified_name(left) if left else "" + right_name = get_qualified_name(right) if right else "" + return f"{left_name}::{right_name}" if left_name and right_name else (left_name or right_name) + elif node.type == "function_declarator": + # 函数声明器:取其 declarator 部分 + func_declarator = node.child_by_field_name("declarator") + if func_declarator: + return get_qualified_name(func_declarator) + # 兜底:直接取文本 + return node.text.decode('utf-8').split('(')[0].strip() + + return get_qualified_name(declarator) +def extract_qualified_name_from_field(field_node): + """递归解析 field_expression,如 A::B::foo""" + def _extract(n): + if n.type == "identifier": + return n.text.decode('utf-8') + elif n.type == "field_expression": + # left.field -> argument.field + left = n.child_by_field_name("argument") or (n.children[0] if n.children else None) + right = n.child_by_field_name("field") or (n.children[-1] if n.children else None) + left_str = _extract(left) if left else "" + right_str = _extract(right) if right else "" + if left_str and right_str: + return f"{left_str}::{right_str}" + return left_str or right_str + else: + # 兜底:取文本,截断参数列表 + text = n.text.decode('utf-8') + return text.split('(')[0].strip() + return _extract(field_node) + + +def extract_full_function_name(func_def_node): + """ + 从 function_definition 节点中提取干净的函数名(支持 A::foo) + """ + declarator = func_def_node.child_by_field_name("declarator") + if not declarator: + return "" + + # 情况1: declarator 是 function_declarator(最常见) + if declarator.type == "function_declarator": + inner_declarator = declarator.child_by_field_name("declarator") + if not inner_declarator: + return "" + if inner_declarator.type == "identifier": + return inner_declarator.text.decode('utf-8') + elif inner_declarator.type == "field_expression": + return extract_qualified_name_from_field(inner_declarator) + else: + # 兜底处理(如数组、指针等,但函数名通常不会这样) + raw = inner_declarator.text.decode('utf-8') + return raw.split('(')[0].strip() + + # 情况2: declarator 直接是 identifier(旧式C或简单情况) + elif declarator.type == "identifier": + return declarator.text.decode('utf-8') + + # 情况3: 其他类型(如 pointer_declarator),尝试兜底 + else: + raw = declarator.text.decode('utf-8') + return raw.split('(')[0].strip() +def traverse_type(node, target_type): + """递归遍历 AST 节点,查找指定类型""" + if node.type == target_type: + yield node + for child in node.children: + yield from traverse_type(child, target_type) +# === 主处理流程(仅优化关联关系处理)=== +def build_rag_database(): + """严格保留Tree-sitter/Qwen调用流程,仅优化关联关系处理""" + global FUNCTION_CALL_GRAPH, FILE_DEPENDENCIES, CALLED_BY_GRAPH + + # 清空旧图(避免多次调用累积) + FUNCTION_CALL_GRAPH.clear() + FILE_DEPENDENCIES.clear() + CALLED_BY_GRAPH.clear() + + dimension = 1024 + index = faiss.IndexFlatL2(dimension) + metadata_list = [] + + import re + + # 全局图(确保在阶段1前清空) + FUNCTION_CALL_GRAPH = {} + CALLED_BY_GRAPH = {} + FILE_DEPENDENCIES = {} + + def extract_callee_name(func_node): + """从 call_expression 的 function 节点中提取函数名标识符""" + if func_node.type == "identifier": + return func_node.text.decode('utf-8') + elif func_node.type == "field_expression": # 如 A::func 或 a.b + field = func_node.child_by_field_name("field") + if field and field.type == "identifier": + return field.text.decode('utf-8') + # 递归处理复杂命名空间 + text = func_node.text.decode('utf-8') + parts = re.split(r'[.:]+', text) + return parts[-1].split('(')[0].strip() if parts else "" + elif func_node.type == "member_expression": # C++ 风格: obj->method 或 obj.method + member = func_node.child_by_field_name("member") + if member and member.type == "identifier": + return member.text.decode('utf-8') + # 兜底 + text = func_node.text.decode('utf-8') + for sep in ['->', '.']: + if sep in text: + return text.split(sep)[-1].split('(')[0].strip() + return text.split('(')[0].strip() + else: + # 兜底:从文本中提取最可能的函数名 + text = func_node.text.decode('utf-8') + # 移除模板 <...> + text = re.sub(r'<[^>]*>', '', text) + # 取括号前部分 + base = text.split('(')[0].strip() + # 处理 a::b::c -> c + for sep in ['::', '->', '.']: + if sep in base: + base = base.split(sep)[-1] + # 简单合法性检查 + if base and base[0].isalpha() and all(c.isalnum() or c in '_$' for c in base): + return base + return "" + + def collect_call_expressions(node): + """递归遍历 AST 节点,收集所有 call_expression 中的 callee 名称(去重)""" + calls = set() + + def traverse(n): + if n.type == "call_expression": + func_node = n.child_by_field_name("function") + if func_node: + callee = extract_callee_name(func_node) + if callee != "": + calls.add(callee) + # 递归子节点 + for child in n.children: + traverse(child) + + traverse(node) + return list(calls) + + # ======================== + # 阶段1:构建调用图与依赖图 + # ======================== + logger.info("阶段1: 构建函数调用图和文件依赖图") + + total_funcs = 0 + + for root, _, files in os.walk(PROJECT_ROOT): + if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS): + continue + for file in files: + if Path(file).suffix.lower() in CPP_EXTENSIONS: + file_path = os.path.join(root, file) + logger.debug(f"解析文件构建关联图: {file_path}") + try: + with open(file_path, 'rb') as f: + source_bytes = f.read() + tree = parser.parse(source_bytes) + + # --- 1. 收集头文件依赖 --- + includes = [] + for node in tree.root_node.named_children: + if node.type == "preproc_include": + path_node = node.child_by_field_name("path") + if path_node: + include_text = path_node.text.decode('utf-8').strip('"<>') + includes.append(include_text) + FILE_DEPENDENCIES[file_path] = includes + + # --- 2. 遍历顶层节点,找函数定义 --- + for node in tree.root_node.named_children: + if node.type == "function_definition": + func_name = extract_full_function_name(node) + if func_name == "": + continue + + total_funcs += 1 + FUNCTION_CALL_GRAPH[func_name] = [] # 初始化 + + # 获取函数体(body) + body = node.child_by_field_name("body") + if body: + callees = collect_call_expressions(body) + FUNCTION_CALL_GRAPH[func_name] = callees + + except Exception as e: + logger.error(f"解析文件失败 {file_path}: {e}", exc_info=True) + + # --- 3. 构建反向调用图(谁调用了我)--- + CALLED_BY_GRAPH.clear() + for caller, callees in FUNCTION_CALL_GRAPH.items(): + for callee in callees: + if callee not in CALLED_BY_GRAPH: + CALLED_BY_GRAPH[callee] = [] + CALLED_BY_GRAPH[callee].append(caller) + + logger.info(f"阶段1完成。共识别 {total_funcs} 个函数,{len(FILE_DEPENDENCIES)} 个文件依赖。") + # === 阶段2: 生成摘要并构建RAG === + logger.info("阶段2: 生成函数摘要并构建RAG") + for root, _, files in os.walk(PROJECT_ROOT): + if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS): + continue + for file in files: + if Path(file).suffix.lower() in CPP_EXTENSIONS: + file_path = os.path.join(root, file) + logger.debug(f"处理文件: {file_path}") + try: + with open(file_path, 'rb') as f: + source = f.read() + tree = parser.parse(source) + source_lines = source.split(b'\n') + + for node in tree.root_node.named_children: + if node.type == "function_definition": + # === 关键修正:使用统一的函数名提取逻辑 === + func_name = extract_full_function_name(node) + if func_name == "": + logger.warning(f"无法解析函数名,跳过节点: {node.text[:100]}...") + continue + + # 提取前置注释(保持不变) + comment = extract_comment_before(node, source_lines) + + # 提取代码片段(保持不变) + func_code = extract_code_snippet(tree, node, source_lines) + print(func_code) + # 生成摘要(保持不变) + logic = generate_code_logic(func_name, comment, func_code, file_path) + summary = generate_code_summary(func_name, comment,logic,func_code, file_path) + print(summary) + # 从全局图中获取关系(确保 func_name 与阶段1一致!) + called_functions = FUNCTION_CALL_GRAPH.get(func_name, []) + caller_functions = CALLED_BY_GRAPH.get(func_name, []) + included_headers = FILE_DEPENDENCIES.get(file_path, []) + + # 构建上下文文本 + context_text = ( + f"【实体类型】函数\n" + f"【函数名】{func_name}\n" + f"【所在文件】{os.path.basename(file_path)}\n" + f"【代码注释】{comment.strip() if comment else '无'}\n" + f"【功能摘要】{summary}\n" + f"【流程逻辑】{logic}\n" + f"【调用的函数】{', '.join(called_functions) if called_functions else '无'}\n" + f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else '无'}\n" + f"【包含的头文件】{', '.join(included_headers) if included_headers else '无'}" + ) + + # 生成嵌入向量 + vector = get_qwen_embedding(context_text) + + # 构建元数据 + metadata = { + "file": file_path, + "line": node.start_point[0] + 1, + "type": "function", + "name": func_name, + "summary": summary, + "logic": logic, + "calls": called_functions, + "called_by": caller_functions, + "includes": included_headers, + "comment": comment or "" + } + + # 添加到向量库 + index.add(np.array([vector], dtype='float32')) + metadata_list.append(metadata) + + except Exception as e: + logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True) + # 保存结果 + faiss.write_index(index, VECTOR_DB_PATH) + with open(METADATA_PATH, 'w', encoding='utf-8') as f: + json.dump(metadata_list, f, indent=2, ensure_ascii=False) + + logger.info(f"知识库构建完成! 总实体: {len(metadata_list)}") + logger.info(f"向量数据库保存至: {VECTOR_DB_PATH}") + logger.info(f"元数据保存至: {METADATA_PATH}") + + return index, metadata_list # 如果你确实需要返回,可以放这里 + + +if __name__ == "__main__": + build_rag_database() \ No newline at end of file