Files
CodeKnowledgeBuild/Knowledge_Build.py

498 lines
21 KiB
Python
Raw Normal View History

2026-02-04 14:44:45 +08:00
#!/usr/bin/env python3
"""
卫星星务软件代码RAG知识库构建工具严格保留Tree-sitter/Qwen调用方式
- 仅优化关联关系处理逻辑
- 保留原始Tree-sitter解析和Qwen-V3调用流程
- 通过全局关联图确保关联关系正确性
- 适用于卫星软件项目C/C++
"""
import os
import re
import json
from collections import defaultdict
from pathlib import Path
import numpy as np
import faiss
import tree_sitter
import tree_sitter_c
import tree_sitter_cpp
from openai import OpenAI
from tree_sitter import Language, Parser
import requests
from tqdm import tqdm
import logging
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("rag_builder.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# === 配置常量 ===
QWEN_API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
QWEN_API_KEY = ""
PROJECT_ROOT = "C:/Users/Administrator/Desktop\静态分析问题单过滤/gs-zk和扫描结果/PrjAttCtrlMng" # 替换为实际项目路径
VECTOR_DB_PATH = "satellite_rag.faiss"
METADATA_PATH = "satellite_rag_metadata.json"
MAX_CODE_LENGTH = 800 # 代码片段最大长度
# === 全局状态 ===
FUNCTION_CALL_GRAPH = defaultdict(list)
FILE_DEPENDENCIES = defaultdict(list)
CALLED_BY_GRAPH = defaultdict(list) # 反向图callee -> [caller]
CPP_LANGUAGE = Language(tree_sitter_cpp.language())
parser = Parser()
parser.language=CPP_LANGUAGE
CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external','Debug'}
def generate_code_summary(func_name, comment, logic, code_snippet, file_path):
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
caller_functions = CALLED_BY_GRAPH.get(func_name, []) # ← 直接获取O(1)
included_headers = FILE_DEPENDENCIES.get(file_path, [])
prompt = f"""
你是一名资深的航天软件工程师请总结以下C++函数的核心功能**必须严格包含以下6点**
1. 函数流程与逻辑:
{logic}
2. 函数目的:
结合航天术语总结必须根据上述函数流程与逻辑推导生成不得依赖假设
3. 输入参数名称类型作用用括号列出
4. 返回值类型和含义
5. 与其他函数的关联关系
- 被调用的函数: {', '.join(called_functions) or ''}
- 调用此函数的函数: {', '.join(caller_functions) or ''}
6. 与跨文件关联的函数头文件: {', '.join(included_headers) or ''}
---
函数名: {func_name}
注释: {comment or ''}
代码片段截断至{MAX_CODE_LENGTH}字符:
{code_snippet[:MAX_CODE_LENGTH]}
输出格式: "功能: [按上述6点组织的连贯段落不要编号]"
"""
headers = {"Authorization": f"Bearer {QWEN_API_KEY}", "Content-Type": "application/json"}
payload = {
"model": "qwen-max",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.1
}
response = requests.post(
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
headers=headers,
json=payload
)
response.raise_for_status()
summary = response.json()['choices'][0]['message']['content'].strip()
return summary.replace("功能: ", "")
def generate_code_logic(func_name, comment, code_snippet, file_path):
prompt = f"""
你是一名资深的航天软件工程师请生成以下C++函数的markdown格式的函数流程图并总结函数的核心流程逻辑
"""
# 你的提示词保持不变
headers = {
"Authorization": f"Bearer {QWEN_API_KEY}",
"Content-Type": "application/json"
}
payload = {
"model": "qwen-max",
"messages": [ # ← 直接顶层字段,无 "input" 包裹
{"role": "user", "content": prompt}
],
"temperature": 0.1 # 可选:提高摘要一致性
}
# 使用 OpenAI 兼容端点
response = requests.post(
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
headers=headers,
json=payload
)
response.raise_for_status()
result = response.json()
logic = result['choices'][0]['message']['content'].strip() # ← 注意路径变化!
return logic
def safe_decode(b: bytes) -> str:
for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
try:
return b.decode(encoding)
except UnicodeDecodeError:
continue
return b.decode('utf-8', errors='replace')
# === 严格保留的代码提取辅助函数 ===
def extract_comment_before(node, source_lines):
"""严格保留注释提取逻辑(已验证正确)"""
start_line = node.start_point[0]
comment_lines = []
for i in range(start_line - 1, max(0, start_line - 10), -1):
line = safe_decode(source_lines[i]).strip()
if line.startswith('//') or line.startswith('/*'):
comment_lines.append(line)
return " ".join(comment_lines[::-1]) or "无注释"
def extract_code_snippet(tree, node, source_lines):
"""严格保留代码片段提取逻辑(已验证正确)"""
start_line = node.start_point[0]
end_line = node.end_point[0]
code_lines = []
for i in range(start_line, end_line + 1):
if i < len(source_lines):
code_lines.append(safe_decode(source_lines[i]))
return "\n".join(code_lines)
# === 严格保留的向量化嵌入 ===
def get_qwen_embedding(text: str) -> np.ndarray:
"""
调用 DashScope text-embedding-v4 模型生成 1024 维文本嵌入向量
使用 OpenAI 兼容 SDK 方式调用
"""
try:
client = OpenAI(
api_key=QWEN_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
response = client.embeddings.create(
model="text-embedding-v4",
input=text # OpenAI SDK 支持单个字符串或字符串列表
)
embedding = response.data[0].embedding
return np.array(embedding, dtype='float32')
except Exception as e:
# 假设 logger 已定义;若未定义,可临时用 print 替代
try:
logger.error(f"获取嵌入失败: {e}")
except NameError:
print(f"获取嵌入失败: {e}")
return np.zeros(1024, dtype='float32')
def extract_function_name(func_def_node):
"""
function_definition 节点中提取完整的函数名含命名空间/类作用域
"""
declarator = func_def_node.child_by_field_name("declarator")
if not declarator:
return "<unknown>"
# 处理 A::B::func 这种形式:可能是 field_expression 嵌套
def get_qualified_name(node):
if node.type == "identifier":
return node.text.decode('utf-8')
elif node.type == "field_expression":
# left.field -> left 可能是 identifier 或 field_expression
left = node.child_by_field_name("argument") or node.children[0]
right = node.child_by_field_name("field") or node.children[-1]
left_name = get_qualified_name(left) if left else ""
right_name = get_qualified_name(right) if right else ""
return f"{left_name}::{right_name}" if left_name and right_name else (left_name or right_name)
elif node.type == "function_declarator":
# 函数声明器:取其 declarator 部分
func_declarator = node.child_by_field_name("declarator")
if func_declarator:
return get_qualified_name(func_declarator)
# 兜底:直接取文本
return node.text.decode('utf-8').split('(')[0].strip()
return get_qualified_name(declarator)
def extract_qualified_name_from_field(field_node):
"""递归解析 field_expression如 A::B::foo"""
def _extract(n):
if n.type == "identifier":
return n.text.decode('utf-8')
elif n.type == "field_expression":
# left.field -> argument.field
left = n.child_by_field_name("argument") or (n.children[0] if n.children else None)
right = n.child_by_field_name("field") or (n.children[-1] if n.children else None)
left_str = _extract(left) if left else ""
right_str = _extract(right) if right else ""
if left_str and right_str:
return f"{left_str}::{right_str}"
return left_str or right_str
else:
# 兜底:取文本,截断参数列表
text = n.text.decode('utf-8')
return text.split('(')[0].strip()
return _extract(field_node)
def extract_full_function_name(func_def_node):
"""
function_definition 节点中提取干净的函数名支持 A::foo
"""
declarator = func_def_node.child_by_field_name("declarator")
if not declarator:
return "<unknown>"
# 情况1: declarator 是 function_declarator最常见
if declarator.type == "function_declarator":
inner_declarator = declarator.child_by_field_name("declarator")
if not inner_declarator:
return "<unknown>"
if inner_declarator.type == "identifier":
return inner_declarator.text.decode('utf-8')
elif inner_declarator.type == "field_expression":
return extract_qualified_name_from_field(inner_declarator)
else:
# 兜底处理(如数组、指针等,但函数名通常不会这样)
raw = inner_declarator.text.decode('utf-8')
return raw.split('(')[0].strip()
# 情况2: declarator 直接是 identifier旧式C或简单情况
elif declarator.type == "identifier":
return declarator.text.decode('utf-8')
# 情况3: 其他类型(如 pointer_declarator尝试兜底
else:
raw = declarator.text.decode('utf-8')
return raw.split('(')[0].strip()
def traverse_type(node, target_type):
"""递归遍历 AST 节点,查找指定类型"""
if node.type == target_type:
yield node
for child in node.children:
yield from traverse_type(child, target_type)
# === 主处理流程(仅优化关联关系处理)===
def build_rag_database():
"""严格保留Tree-sitter/Qwen调用流程仅优化关联关系处理"""
global FUNCTION_CALL_GRAPH, FILE_DEPENDENCIES, CALLED_BY_GRAPH
# 清空旧图(避免多次调用累积)
FUNCTION_CALL_GRAPH.clear()
FILE_DEPENDENCIES.clear()
CALLED_BY_GRAPH.clear()
dimension = 1024
index = faiss.IndexFlatL2(dimension)
metadata_list = []
import re
# 全局图确保在阶段1前清空
FUNCTION_CALL_GRAPH = {}
CALLED_BY_GRAPH = {}
FILE_DEPENDENCIES = {}
def extract_callee_name(func_node):
"""从 call_expression 的 function 节点中提取函数名标识符"""
if func_node.type == "identifier":
return func_node.text.decode('utf-8')
elif func_node.type == "field_expression": # 如 A::func 或 a.b
field = func_node.child_by_field_name("field")
if field and field.type == "identifier":
return field.text.decode('utf-8')
# 递归处理复杂命名空间
text = func_node.text.decode('utf-8')
parts = re.split(r'[.:]+', text)
return parts[-1].split('(')[0].strip() if parts else "<unknown>"
elif func_node.type == "member_expression": # C++ 风格: obj->method 或 obj.method
member = func_node.child_by_field_name("member")
if member and member.type == "identifier":
return member.text.decode('utf-8')
# 兜底
text = func_node.text.decode('utf-8')
for sep in ['->', '.']:
if sep in text:
return text.split(sep)[-1].split('(')[0].strip()
return text.split('(')[0].strip()
else:
# 兜底:从文本中提取最可能的函数名
text = func_node.text.decode('utf-8')
# 移除模板 <...>
text = re.sub(r'<[^>]*>', '', text)
# 取括号前部分
base = text.split('(')[0].strip()
# 处理 a::b::c -> c
for sep in ['::', '->', '.']:
if sep in base:
base = base.split(sep)[-1]
# 简单合法性检查
if base and base[0].isalpha() and all(c.isalnum() or c in '_$' for c in base):
return base
return "<unknown>"
def collect_call_expressions(node):
"""递归遍历 AST 节点,收集所有 call_expression 中的 callee 名称(去重)"""
calls = set()
def traverse(n):
if n.type == "call_expression":
func_node = n.child_by_field_name("function")
if func_node:
callee = extract_callee_name(func_node)
if callee != "<unknown>":
calls.add(callee)
# 递归子节点
for child in n.children:
traverse(child)
traverse(node)
return list(calls)
# ========================
# 阶段1构建调用图与依赖图
# ========================
logger.info("阶段1: 构建函数调用图和文件依赖图")
total_funcs = 0
for root, _, files in os.walk(PROJECT_ROOT):
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
continue
for file in files:
if Path(file).suffix.lower() in CPP_EXTENSIONS:
file_path = os.path.join(root, file)
logger.debug(f"解析文件构建关联图: {file_path}")
try:
with open(file_path, 'rb') as f:
source_bytes = f.read()
tree = parser.parse(source_bytes)
# --- 1. 收集头文件依赖 ---
includes = []
for node in tree.root_node.named_children:
if node.type == "preproc_include":
path_node = node.child_by_field_name("path")
if path_node:
include_text = path_node.text.decode('utf-8').strip('"<>')
includes.append(include_text)
FILE_DEPENDENCIES[file_path] = includes
# --- 2. 遍历顶层节点,找函数定义 ---
for node in tree.root_node.named_children:
if node.type == "function_definition":
func_name = extract_full_function_name(node)
if func_name == "<unknown>":
continue
total_funcs += 1
FUNCTION_CALL_GRAPH[func_name] = [] # 初始化
# 获取函数体body
body = node.child_by_field_name("body")
if body:
callees = collect_call_expressions(body)
FUNCTION_CALL_GRAPH[func_name] = callees
except Exception as e:
logger.error(f"解析文件失败 {file_path}: {e}", exc_info=True)
# --- 3. 构建反向调用图(谁调用了我)---
CALLED_BY_GRAPH.clear()
for caller, callees in FUNCTION_CALL_GRAPH.items():
for callee in callees:
if callee not in CALLED_BY_GRAPH:
CALLED_BY_GRAPH[callee] = []
CALLED_BY_GRAPH[callee].append(caller)
logger.info(f"阶段1完成。共识别 {total_funcs} 个函数,{len(FILE_DEPENDENCIES)} 个文件依赖。")
# === 阶段2: 生成摘要并构建RAG ===
logger.info("阶段2: 生成函数摘要并构建RAG")
for root, _, files in os.walk(PROJECT_ROOT):
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
continue
for file in files:
if Path(file).suffix.lower() in CPP_EXTENSIONS:
file_path = os.path.join(root, file)
logger.debug(f"处理文件: {file_path}")
try:
with open(file_path, 'rb') as f:
source = f.read()
tree = parser.parse(source)
source_lines = source.split(b'\n')
for node in tree.root_node.named_children:
if node.type == "function_definition":
# === 关键修正:使用统一的函数名提取逻辑 ===
func_name = extract_full_function_name(node)
if func_name == "<unknown>":
logger.warning(f"无法解析函数名,跳过节点: {node.text[:100]}...")
continue
# 提取前置注释(保持不变)
comment = extract_comment_before(node, source_lines)
# 提取代码片段(保持不变)
func_code = extract_code_snippet(tree, node, source_lines)
print(func_code)
# 生成摘要(保持不变)
logic = generate_code_logic(func_name, comment, func_code, file_path)
summary = generate_code_summary(func_name, comment,logic,func_code, file_path)
print(summary)
# 从全局图中获取关系(确保 func_name 与阶段1一致
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
caller_functions = CALLED_BY_GRAPH.get(func_name, [])
included_headers = FILE_DEPENDENCIES.get(file_path, [])
# 构建上下文文本
context_text = (
f"【实体类型】函数\n"
f"【函数名】{func_name}\n"
f"【所在文件】{os.path.basename(file_path)}\n"
f"【代码注释】{comment.strip() if comment else ''}\n"
f"【功能摘要】{summary}\n"
f"【流程逻辑】{logic}\n"
f"【调用的函数】{', '.join(called_functions) if called_functions else ''}\n"
f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else ''}\n"
f"【包含的头文件】{', '.join(included_headers) if included_headers else ''}"
)
# 生成嵌入向量
vector = get_qwen_embedding(context_text)
# 构建元数据
metadata = {
"file": file_path,
"line": node.start_point[0] + 1,
"type": "function",
"name": func_name,
"summary": summary,
"logic": logic,
"calls": called_functions,
"called_by": caller_functions,
"includes": included_headers,
"comment": comment or ""
}
# 添加到向量库
index.add(np.array([vector], dtype='float32'))
metadata_list.append(metadata)
except Exception as e:
logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True)
# 保存结果
faiss.write_index(index, VECTOR_DB_PATH)
with open(METADATA_PATH, 'w', encoding='utf-8') as f:
json.dump(metadata_list, f, indent=2, ensure_ascii=False)
logger.info(f"知识库构建完成! 总实体: {len(metadata_list)}")
logger.info(f"向量数据库保存至: {VECTOR_DB_PATH}")
logger.info(f"元数据保存至: {METADATA_PATH}")
return index, metadata_list # 如果你确实需要返回,可以放这里
if __name__ == "__main__":
build_rag_database()