498 lines
21 KiB
Python
498 lines
21 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
卫星星务软件代码RAG知识库构建工具(严格保留Tree-sitter/Qwen调用方式)
|
||
- 仅优化关联关系处理逻辑
|
||
- 保留原始Tree-sitter解析和Qwen-V3调用流程
|
||
- 通过全局关联图确保关联关系正确性
|
||
- 适用于卫星软件项目(C/C++)
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import json
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import faiss
|
||
import tree_sitter
|
||
import tree_sitter_c
|
||
import tree_sitter_cpp
|
||
from openai import OpenAI
|
||
from tree_sitter import Language, Parser
|
||
import requests
|
||
from tqdm import tqdm
|
||
import logging
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.FileHandler("rag_builder.log"),
|
||
logging.StreamHandler()
|
||
]
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# === 配置常量 ===
|
||
QWEN_API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||
QWEN_API_KEY = ""
|
||
PROJECT_ROOT = "C:/Users/Administrator/Desktop\静态分析问题单过滤/gs-zk和扫描结果/PrjAttCtrlMng" # 替换为实际项目路径
|
||
VECTOR_DB_PATH = "satellite_rag.faiss"
|
||
METADATA_PATH = "satellite_rag_metadata.json"
|
||
MAX_CODE_LENGTH = 800 # 代码片段最大长度
|
||
|
||
|
||
# === 全局状态 ===
|
||
FUNCTION_CALL_GRAPH = defaultdict(list)
|
||
FILE_DEPENDENCIES = defaultdict(list)
|
||
CALLED_BY_GRAPH = defaultdict(list) # 反向图:callee -> [caller]
|
||
CPP_LANGUAGE = Language(tree_sitter_cpp.language())
|
||
parser = Parser()
|
||
parser.language=CPP_LANGUAGE
|
||
CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
|
||
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external','Debug'}
|
||
|
||
def generate_code_summary(func_name, comment, logic, code_snippet, file_path):
|
||
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
|
||
caller_functions = CALLED_BY_GRAPH.get(func_name, []) # ← 直接获取,O(1)
|
||
included_headers = FILE_DEPENDENCIES.get(file_path, [])
|
||
|
||
prompt = f"""
|
||
你是一名资深的航天软件工程师,请总结以下C++函数的核心功能,**必须严格包含以下6点**:
|
||
|
||
1. 函数流程与逻辑:
|
||
{logic}
|
||
|
||
2. 函数目的:
|
||
→ 结合航天术语总结,必须根据上述“函数流程与逻辑”推导生成,不得依赖假设!
|
||
|
||
3. 输入参数(名称、类型、作用,用括号列出)
|
||
|
||
4. 返回值(类型和含义)
|
||
|
||
5. 与其他函数的关联关系:
|
||
- 被调用的函数: {', '.join(called_functions) or '无'}
|
||
- 调用此函数的函数: {', '.join(caller_functions) or '无'}
|
||
|
||
6. 与跨文件关联的函数(头文件): {', '.join(included_headers) or '无'}
|
||
|
||
---
|
||
函数名: {func_name}
|
||
注释: {comment or '无'}
|
||
代码片段(截断至{MAX_CODE_LENGTH}字符):
|
||
{code_snippet[:MAX_CODE_LENGTH]}
|
||
|
||
输出格式: "功能: [按上述6点组织的连贯段落,不要编号]"
|
||
"""
|
||
|
||
headers = {"Authorization": f"Bearer {QWEN_API_KEY}", "Content-Type": "application/json"}
|
||
payload = {
|
||
"model": "qwen-max",
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"temperature": 0.1
|
||
}
|
||
|
||
response = requests.post(
|
||
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||
headers=headers,
|
||
json=payload
|
||
)
|
||
response.raise_for_status()
|
||
summary = response.json()['choices'][0]['message']['content'].strip()
|
||
return summary.replace("功能: ", "")
|
||
def generate_code_logic(func_name, comment, code_snippet, file_path):
|
||
prompt = f"""
|
||
你是一名资深的航天软件工程师,请生成以下C++函数的markdown格式的函数流程图,并总结函数的核心流程逻辑。
|
||
"""
|
||
# 你的提示词保持不变
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {QWEN_API_KEY}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
payload = {
|
||
"model": "qwen-max",
|
||
"messages": [ # ← 直接顶层字段,无 "input" 包裹
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
"temperature": 0.1 # 可选:提高摘要一致性
|
||
}
|
||
|
||
# 使用 OpenAI 兼容端点
|
||
response = requests.post(
|
||
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||
headers=headers,
|
||
json=payload
|
||
)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
logic = result['choices'][0]['message']['content'].strip() # ← 注意路径变化!
|
||
return logic
|
||
def safe_decode(b: bytes) -> str:
|
||
for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
|
||
try:
|
||
return b.decode(encoding)
|
||
except UnicodeDecodeError:
|
||
continue
|
||
return b.decode('utf-8', errors='replace')
|
||
|
||
# === 严格保留的代码提取辅助函数 ===
|
||
def extract_comment_before(node, source_lines):
|
||
"""严格保留注释提取逻辑(已验证正确)"""
|
||
start_line = node.start_point[0]
|
||
comment_lines = []
|
||
|
||
for i in range(start_line - 1, max(0, start_line - 10), -1):
|
||
line = safe_decode(source_lines[i]).strip()
|
||
if line.startswith('//') or line.startswith('/*'):
|
||
comment_lines.append(line)
|
||
|
||
return " ".join(comment_lines[::-1]) or "无注释"
|
||
|
||
|
||
def extract_code_snippet(tree, node, source_lines):
|
||
"""严格保留代码片段提取逻辑(已验证正确)"""
|
||
start_line = node.start_point[0]
|
||
end_line = node.end_point[0]
|
||
|
||
code_lines = []
|
||
for i in range(start_line, end_line + 1):
|
||
if i < len(source_lines):
|
||
code_lines.append(safe_decode(source_lines[i]))
|
||
|
||
return "\n".join(code_lines)
|
||
|
||
|
||
# === 严格保留的向量化嵌入 ===
|
||
def get_qwen_embedding(text: str) -> np.ndarray:
|
||
"""
|
||
调用 DashScope 的 text-embedding-v4 模型生成 1024 维文本嵌入向量。
|
||
使用 OpenAI 兼容 SDK 方式调用。
|
||
"""
|
||
try:
|
||
client = OpenAI(
|
||
api_key=QWEN_API_KEY,
|
||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||
)
|
||
response = client.embeddings.create(
|
||
model="text-embedding-v4",
|
||
input=text # OpenAI SDK 支持单个字符串或字符串列表
|
||
)
|
||
embedding = response.data[0].embedding
|
||
return np.array(embedding, dtype='float32')
|
||
except Exception as e:
|
||
# 假设 logger 已定义;若未定义,可临时用 print 替代
|
||
try:
|
||
logger.error(f"获取嵌入失败: {e}")
|
||
except NameError:
|
||
print(f"获取嵌入失败: {e}")
|
||
return np.zeros(1024, dtype='float32')
|
||
def extract_function_name(func_def_node):
|
||
"""
|
||
从 function_definition 节点中提取完整的函数名(含命名空间/类作用域)
|
||
"""
|
||
declarator = func_def_node.child_by_field_name("declarator")
|
||
if not declarator:
|
||
return "<unknown>"
|
||
|
||
# 处理 A::B::func 这种形式:可能是 field_expression 嵌套
|
||
def get_qualified_name(node):
|
||
if node.type == "identifier":
|
||
return node.text.decode('utf-8')
|
||
elif node.type == "field_expression":
|
||
# left.field -> left 可能是 identifier 或 field_expression
|
||
left = node.child_by_field_name("argument") or node.children[0]
|
||
right = node.child_by_field_name("field") or node.children[-1]
|
||
left_name = get_qualified_name(left) if left else ""
|
||
right_name = get_qualified_name(right) if right else ""
|
||
return f"{left_name}::{right_name}" if left_name and right_name else (left_name or right_name)
|
||
elif node.type == "function_declarator":
|
||
# 函数声明器:取其 declarator 部分
|
||
func_declarator = node.child_by_field_name("declarator")
|
||
if func_declarator:
|
||
return get_qualified_name(func_declarator)
|
||
# 兜底:直接取文本
|
||
return node.text.decode('utf-8').split('(')[0].strip()
|
||
|
||
return get_qualified_name(declarator)
|
||
def extract_qualified_name_from_field(field_node):
|
||
"""递归解析 field_expression,如 A::B::foo"""
|
||
def _extract(n):
|
||
if n.type == "identifier":
|
||
return n.text.decode('utf-8')
|
||
elif n.type == "field_expression":
|
||
# left.field -> argument.field
|
||
left = n.child_by_field_name("argument") or (n.children[0] if n.children else None)
|
||
right = n.child_by_field_name("field") or (n.children[-1] if n.children else None)
|
||
left_str = _extract(left) if left else ""
|
||
right_str = _extract(right) if right else ""
|
||
if left_str and right_str:
|
||
return f"{left_str}::{right_str}"
|
||
return left_str or right_str
|
||
else:
|
||
# 兜底:取文本,截断参数列表
|
||
text = n.text.decode('utf-8')
|
||
return text.split('(')[0].strip()
|
||
return _extract(field_node)
|
||
|
||
|
||
def extract_full_function_name(func_def_node):
|
||
"""
|
||
从 function_definition 节点中提取干净的函数名(支持 A::foo)
|
||
"""
|
||
declarator = func_def_node.child_by_field_name("declarator")
|
||
if not declarator:
|
||
return "<unknown>"
|
||
|
||
# 情况1: declarator 是 function_declarator(最常见)
|
||
if declarator.type == "function_declarator":
|
||
inner_declarator = declarator.child_by_field_name("declarator")
|
||
if not inner_declarator:
|
||
return "<unknown>"
|
||
if inner_declarator.type == "identifier":
|
||
return inner_declarator.text.decode('utf-8')
|
||
elif inner_declarator.type == "field_expression":
|
||
return extract_qualified_name_from_field(inner_declarator)
|
||
else:
|
||
# 兜底处理(如数组、指针等,但函数名通常不会这样)
|
||
raw = inner_declarator.text.decode('utf-8')
|
||
return raw.split('(')[0].strip()
|
||
|
||
# 情况2: declarator 直接是 identifier(旧式C或简单情况)
|
||
elif declarator.type == "identifier":
|
||
return declarator.text.decode('utf-8')
|
||
|
||
# 情况3: 其他类型(如 pointer_declarator),尝试兜底
|
||
else:
|
||
raw = declarator.text.decode('utf-8')
|
||
return raw.split('(')[0].strip()
|
||
def traverse_type(node, target_type):
|
||
"""递归遍历 AST 节点,查找指定类型"""
|
||
if node.type == target_type:
|
||
yield node
|
||
for child in node.children:
|
||
yield from traverse_type(child, target_type)
|
||
# === 主处理流程(仅优化关联关系处理)===
|
||
def build_rag_database():
|
||
"""严格保留Tree-sitter/Qwen调用流程,仅优化关联关系处理"""
|
||
global FUNCTION_CALL_GRAPH, FILE_DEPENDENCIES, CALLED_BY_GRAPH
|
||
|
||
# 清空旧图(避免多次调用累积)
|
||
FUNCTION_CALL_GRAPH.clear()
|
||
FILE_DEPENDENCIES.clear()
|
||
CALLED_BY_GRAPH.clear()
|
||
|
||
dimension = 1024
|
||
index = faiss.IndexFlatL2(dimension)
|
||
metadata_list = []
|
||
|
||
import re
|
||
|
||
# 全局图(确保在阶段1前清空)
|
||
FUNCTION_CALL_GRAPH = {}
|
||
CALLED_BY_GRAPH = {}
|
||
FILE_DEPENDENCIES = {}
|
||
|
||
def extract_callee_name(func_node):
|
||
"""从 call_expression 的 function 节点中提取函数名标识符"""
|
||
if func_node.type == "identifier":
|
||
return func_node.text.decode('utf-8')
|
||
elif func_node.type == "field_expression": # 如 A::func 或 a.b
|
||
field = func_node.child_by_field_name("field")
|
||
if field and field.type == "identifier":
|
||
return field.text.decode('utf-8')
|
||
# 递归处理复杂命名空间
|
||
text = func_node.text.decode('utf-8')
|
||
parts = re.split(r'[.:]+', text)
|
||
return parts[-1].split('(')[0].strip() if parts else "<unknown>"
|
||
elif func_node.type == "member_expression": # C++ 风格: obj->method 或 obj.method
|
||
member = func_node.child_by_field_name("member")
|
||
if member and member.type == "identifier":
|
||
return member.text.decode('utf-8')
|
||
# 兜底
|
||
text = func_node.text.decode('utf-8')
|
||
for sep in ['->', '.']:
|
||
if sep in text:
|
||
return text.split(sep)[-1].split('(')[0].strip()
|
||
return text.split('(')[0].strip()
|
||
else:
|
||
# 兜底:从文本中提取最可能的函数名
|
||
text = func_node.text.decode('utf-8')
|
||
# 移除模板 <...>
|
||
text = re.sub(r'<[^>]*>', '', text)
|
||
# 取括号前部分
|
||
base = text.split('(')[0].strip()
|
||
# 处理 a::b::c -> c
|
||
for sep in ['::', '->', '.']:
|
||
if sep in base:
|
||
base = base.split(sep)[-1]
|
||
# 简单合法性检查
|
||
if base and base[0].isalpha() and all(c.isalnum() or c in '_$' for c in base):
|
||
return base
|
||
return "<unknown>"
|
||
|
||
def collect_call_expressions(node):
|
||
"""递归遍历 AST 节点,收集所有 call_expression 中的 callee 名称(去重)"""
|
||
calls = set()
|
||
|
||
def traverse(n):
|
||
if n.type == "call_expression":
|
||
func_node = n.child_by_field_name("function")
|
||
if func_node:
|
||
callee = extract_callee_name(func_node)
|
||
if callee != "<unknown>":
|
||
calls.add(callee)
|
||
# 递归子节点
|
||
for child in n.children:
|
||
traverse(child)
|
||
|
||
traverse(node)
|
||
return list(calls)
|
||
|
||
# ========================
|
||
# 阶段1:构建调用图与依赖图
|
||
# ========================
|
||
logger.info("阶段1: 构建函数调用图和文件依赖图")
|
||
|
||
total_funcs = 0
|
||
|
||
for root, _, files in os.walk(PROJECT_ROOT):
|
||
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
|
||
continue
|
||
for file in files:
|
||
if Path(file).suffix.lower() in CPP_EXTENSIONS:
|
||
file_path = os.path.join(root, file)
|
||
logger.debug(f"解析文件构建关联图: {file_path}")
|
||
try:
|
||
with open(file_path, 'rb') as f:
|
||
source_bytes = f.read()
|
||
tree = parser.parse(source_bytes)
|
||
|
||
# --- 1. 收集头文件依赖 ---
|
||
includes = []
|
||
for node in tree.root_node.named_children:
|
||
if node.type == "preproc_include":
|
||
path_node = node.child_by_field_name("path")
|
||
if path_node:
|
||
include_text = path_node.text.decode('utf-8').strip('"<>')
|
||
includes.append(include_text)
|
||
FILE_DEPENDENCIES[file_path] = includes
|
||
|
||
# --- 2. 遍历顶层节点,找函数定义 ---
|
||
for node in tree.root_node.named_children:
|
||
if node.type == "function_definition":
|
||
func_name = extract_full_function_name(node)
|
||
if func_name == "<unknown>":
|
||
continue
|
||
|
||
total_funcs += 1
|
||
FUNCTION_CALL_GRAPH[func_name] = [] # 初始化
|
||
|
||
# 获取函数体(body)
|
||
body = node.child_by_field_name("body")
|
||
if body:
|
||
callees = collect_call_expressions(body)
|
||
FUNCTION_CALL_GRAPH[func_name] = callees
|
||
|
||
except Exception as e:
|
||
logger.error(f"解析文件失败 {file_path}: {e}", exc_info=True)
|
||
|
||
# --- 3. 构建反向调用图(谁调用了我)---
|
||
CALLED_BY_GRAPH.clear()
|
||
for caller, callees in FUNCTION_CALL_GRAPH.items():
|
||
for callee in callees:
|
||
if callee not in CALLED_BY_GRAPH:
|
||
CALLED_BY_GRAPH[callee] = []
|
||
CALLED_BY_GRAPH[callee].append(caller)
|
||
|
||
logger.info(f"阶段1完成。共识别 {total_funcs} 个函数,{len(FILE_DEPENDENCIES)} 个文件依赖。")
|
||
# === 阶段2: 生成摘要并构建RAG ===
|
||
logger.info("阶段2: 生成函数摘要并构建RAG")
|
||
for root, _, files in os.walk(PROJECT_ROOT):
|
||
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
|
||
continue
|
||
for file in files:
|
||
if Path(file).suffix.lower() in CPP_EXTENSIONS:
|
||
file_path = os.path.join(root, file)
|
||
logger.debug(f"处理文件: {file_path}")
|
||
try:
|
||
with open(file_path, 'rb') as f:
|
||
source = f.read()
|
||
tree = parser.parse(source)
|
||
source_lines = source.split(b'\n')
|
||
|
||
for node in tree.root_node.named_children:
|
||
if node.type == "function_definition":
|
||
# === 关键修正:使用统一的函数名提取逻辑 ===
|
||
func_name = extract_full_function_name(node)
|
||
if func_name == "<unknown>":
|
||
logger.warning(f"无法解析函数名,跳过节点: {node.text[:100]}...")
|
||
continue
|
||
|
||
# 提取前置注释(保持不变)
|
||
comment = extract_comment_before(node, source_lines)
|
||
|
||
# 提取代码片段(保持不变)
|
||
func_code = extract_code_snippet(tree, node, source_lines)
|
||
print(func_code)
|
||
# 生成摘要(保持不变)
|
||
logic = generate_code_logic(func_name, comment, func_code, file_path)
|
||
summary = generate_code_summary(func_name, comment,logic,func_code, file_path)
|
||
print(summary)
|
||
# 从全局图中获取关系(确保 func_name 与阶段1一致!)
|
||
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
|
||
caller_functions = CALLED_BY_GRAPH.get(func_name, [])
|
||
included_headers = FILE_DEPENDENCIES.get(file_path, [])
|
||
|
||
# 构建上下文文本
|
||
context_text = (
|
||
f"【实体类型】函数\n"
|
||
f"【函数名】{func_name}\n"
|
||
f"【所在文件】{os.path.basename(file_path)}\n"
|
||
f"【代码注释】{comment.strip() if comment else '无'}\n"
|
||
f"【功能摘要】{summary}\n"
|
||
f"【流程逻辑】{logic}\n"
|
||
f"【调用的函数】{', '.join(called_functions) if called_functions else '无'}\n"
|
||
f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else '无'}\n"
|
||
f"【包含的头文件】{', '.join(included_headers) if included_headers else '无'}"
|
||
)
|
||
|
||
# 生成嵌入向量
|
||
vector = get_qwen_embedding(context_text)
|
||
|
||
# 构建元数据
|
||
metadata = {
|
||
"file": file_path,
|
||
"line": node.start_point[0] + 1,
|
||
"type": "function",
|
||
"name": func_name,
|
||
"summary": summary,
|
||
"logic": logic,
|
||
"calls": called_functions,
|
||
"called_by": caller_functions,
|
||
"includes": included_headers,
|
||
"comment": comment or ""
|
||
}
|
||
|
||
# 添加到向量库
|
||
index.add(np.array([vector], dtype='float32'))
|
||
metadata_list.append(metadata)
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True)
|
||
# 保存结果
|
||
faiss.write_index(index, VECTOR_DB_PATH)
|
||
with open(METADATA_PATH, 'w', encoding='utf-8') as f:
|
||
json.dump(metadata_list, f, indent=2, ensure_ascii=False)
|
||
|
||
logger.info(f"知识库构建完成! 总实体: {len(metadata_list)}")
|
||
logger.info(f"向量数据库保存至: {VECTOR_DB_PATH}")
|
||
logger.info(f"元数据保存至: {METADATA_PATH}")
|
||
|
||
return index, metadata_list # 如果你确实需要返回,可以放这里
|
||
|
||
|
||
if __name__ == "__main__":
|
||
build_rag_database() |