上传文件至「/」
This commit is contained in:
498
Knowledge_Build.py
Normal file
498
Knowledge_Build.py
Normal file
@@ -0,0 +1,498 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
卫星星务软件代码RAG知识库构建工具(严格保留Tree-sitter/Qwen调用方式)
|
||||
- 仅优化关联关系处理逻辑
|
||||
- 保留原始Tree-sitter解析和Qwen-V3调用流程
|
||||
- 通过全局关联图确保关联关系正确性
|
||||
- 适用于卫星软件项目(C/C++)
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
import tree_sitter
|
||||
import tree_sitter_c
|
||||
import tree_sitter_cpp
|
||||
from openai import OpenAI
|
||||
from tree_sitter import Language, Parser
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("rag_builder.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# === 配置常量 ===
|
||||
QWEN_API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
QWEN_API_KEY = ""
|
||||
PROJECT_ROOT = "C:/Users/Administrator/Desktop\静态分析问题单过滤/gs-zk和扫描结果/PrjAttCtrlMng" # 替换为实际项目路径
|
||||
VECTOR_DB_PATH = "satellite_rag.faiss"
|
||||
METADATA_PATH = "satellite_rag_metadata.json"
|
||||
MAX_CODE_LENGTH = 800 # 代码片段最大长度
|
||||
|
||||
|
||||
# === 全局状态 ===
|
||||
FUNCTION_CALL_GRAPH = defaultdict(list)
|
||||
FILE_DEPENDENCIES = defaultdict(list)
|
||||
CALLED_BY_GRAPH = defaultdict(list) # 反向图:callee -> [caller]
|
||||
CPP_LANGUAGE = Language(tree_sitter_cpp.language())
|
||||
parser = Parser()
|
||||
parser.language=CPP_LANGUAGE
|
||||
CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
|
||||
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external','Debug'}
|
||||
|
||||
def generate_code_summary(func_name, comment, logic, code_snippet, file_path):
|
||||
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
|
||||
caller_functions = CALLED_BY_GRAPH.get(func_name, []) # ← 直接获取,O(1)
|
||||
included_headers = FILE_DEPENDENCIES.get(file_path, [])
|
||||
|
||||
prompt = f"""
|
||||
你是一名资深的航天软件工程师,请总结以下C++函数的核心功能,**必须严格包含以下6点**:
|
||||
|
||||
1. 函数流程与逻辑:
|
||||
{logic}
|
||||
|
||||
2. 函数目的:
|
||||
→ 结合航天术语总结,必须根据上述“函数流程与逻辑”推导生成,不得依赖假设!
|
||||
|
||||
3. 输入参数(名称、类型、作用,用括号列出)
|
||||
|
||||
4. 返回值(类型和含义)
|
||||
|
||||
5. 与其他函数的关联关系:
|
||||
- 被调用的函数: {', '.join(called_functions) or '无'}
|
||||
- 调用此函数的函数: {', '.join(caller_functions) or '无'}
|
||||
|
||||
6. 与跨文件关联的函数(头文件): {', '.join(included_headers) or '无'}
|
||||
|
||||
---
|
||||
函数名: {func_name}
|
||||
注释: {comment or '无'}
|
||||
代码片段(截断至{MAX_CODE_LENGTH}字符):
|
||||
{code_snippet[:MAX_CODE_LENGTH]}
|
||||
|
||||
输出格式: "功能: [按上述6点组织的连贯段落,不要编号]"
|
||||
"""
|
||||
|
||||
headers = {"Authorization": f"Bearer {QWEN_API_KEY}", "Content-Type": "application/json"}
|
||||
payload = {
|
||||
"model": "qwen-max",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.1
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
summary = response.json()['choices'][0]['message']['content'].strip()
|
||||
return summary.replace("功能: ", "")
|
||||
def generate_code_logic(func_name, comment, code_snippet, file_path):
|
||||
prompt = f"""
|
||||
你是一名资深的航天软件工程师,请生成以下C++函数的markdown格式的函数流程图,并总结函数的核心流程逻辑。
|
||||
"""
|
||||
# 你的提示词保持不变
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {QWEN_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": "qwen-max",
|
||||
"messages": [ # ← 直接顶层字段,无 "input" 包裹
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": 0.1 # 可选:提高摘要一致性
|
||||
}
|
||||
|
||||
# 使用 OpenAI 兼容端点
|
||||
response = requests.post(
|
||||
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
logic = result['choices'][0]['message']['content'].strip() # ← 注意路径变化!
|
||||
return logic
|
||||
def safe_decode(b: bytes) -> str:
|
||||
for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
|
||||
try:
|
||||
return b.decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return b.decode('utf-8', errors='replace')
|
||||
|
||||
# === 严格保留的代码提取辅助函数 ===
|
||||
def extract_comment_before(node, source_lines):
|
||||
"""严格保留注释提取逻辑(已验证正确)"""
|
||||
start_line = node.start_point[0]
|
||||
comment_lines = []
|
||||
|
||||
for i in range(start_line - 1, max(0, start_line - 10), -1):
|
||||
line = safe_decode(source_lines[i]).strip()
|
||||
if line.startswith('//') or line.startswith('/*'):
|
||||
comment_lines.append(line)
|
||||
|
||||
return " ".join(comment_lines[::-1]) or "无注释"
|
||||
|
||||
|
||||
def extract_code_snippet(tree, node, source_lines):
|
||||
"""严格保留代码片段提取逻辑(已验证正确)"""
|
||||
start_line = node.start_point[0]
|
||||
end_line = node.end_point[0]
|
||||
|
||||
code_lines = []
|
||||
for i in range(start_line, end_line + 1):
|
||||
if i < len(source_lines):
|
||||
code_lines.append(safe_decode(source_lines[i]))
|
||||
|
||||
return "\n".join(code_lines)
|
||||
|
||||
|
||||
# === 严格保留的向量化嵌入 ===
|
||||
def get_qwen_embedding(text: str) -> np.ndarray:
|
||||
"""
|
||||
调用 DashScope 的 text-embedding-v4 模型生成 1024 维文本嵌入向量。
|
||||
使用 OpenAI 兼容 SDK 方式调用。
|
||||
"""
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=QWEN_API_KEY,
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-v4",
|
||||
input=text # OpenAI SDK 支持单个字符串或字符串列表
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
return np.array(embedding, dtype='float32')
|
||||
except Exception as e:
|
||||
# 假设 logger 已定义;若未定义,可临时用 print 替代
|
||||
try:
|
||||
logger.error(f"获取嵌入失败: {e}")
|
||||
except NameError:
|
||||
print(f"获取嵌入失败: {e}")
|
||||
return np.zeros(1024, dtype='float32')
|
||||
def extract_function_name(func_def_node):
|
||||
"""
|
||||
从 function_definition 节点中提取完整的函数名(含命名空间/类作用域)
|
||||
"""
|
||||
declarator = func_def_node.child_by_field_name("declarator")
|
||||
if not declarator:
|
||||
return "<unknown>"
|
||||
|
||||
# 处理 A::B::func 这种形式:可能是 field_expression 嵌套
|
||||
def get_qualified_name(node):
|
||||
if node.type == "identifier":
|
||||
return node.text.decode('utf-8')
|
||||
elif node.type == "field_expression":
|
||||
# left.field -> left 可能是 identifier 或 field_expression
|
||||
left = node.child_by_field_name("argument") or node.children[0]
|
||||
right = node.child_by_field_name("field") or node.children[-1]
|
||||
left_name = get_qualified_name(left) if left else ""
|
||||
right_name = get_qualified_name(right) if right else ""
|
||||
return f"{left_name}::{right_name}" if left_name and right_name else (left_name or right_name)
|
||||
elif node.type == "function_declarator":
|
||||
# 函数声明器:取其 declarator 部分
|
||||
func_declarator = node.child_by_field_name("declarator")
|
||||
if func_declarator:
|
||||
return get_qualified_name(func_declarator)
|
||||
# 兜底:直接取文本
|
||||
return node.text.decode('utf-8').split('(')[0].strip()
|
||||
|
||||
return get_qualified_name(declarator)
|
||||
def extract_qualified_name_from_field(field_node):
|
||||
"""递归解析 field_expression,如 A::B::foo"""
|
||||
def _extract(n):
|
||||
if n.type == "identifier":
|
||||
return n.text.decode('utf-8')
|
||||
elif n.type == "field_expression":
|
||||
# left.field -> argument.field
|
||||
left = n.child_by_field_name("argument") or (n.children[0] if n.children else None)
|
||||
right = n.child_by_field_name("field") or (n.children[-1] if n.children else None)
|
||||
left_str = _extract(left) if left else ""
|
||||
right_str = _extract(right) if right else ""
|
||||
if left_str and right_str:
|
||||
return f"{left_str}::{right_str}"
|
||||
return left_str or right_str
|
||||
else:
|
||||
# 兜底:取文本,截断参数列表
|
||||
text = n.text.decode('utf-8')
|
||||
return text.split('(')[0].strip()
|
||||
return _extract(field_node)
|
||||
|
||||
|
||||
def extract_full_function_name(func_def_node):
|
||||
"""
|
||||
从 function_definition 节点中提取干净的函数名(支持 A::foo)
|
||||
"""
|
||||
declarator = func_def_node.child_by_field_name("declarator")
|
||||
if not declarator:
|
||||
return "<unknown>"
|
||||
|
||||
# 情况1: declarator 是 function_declarator(最常见)
|
||||
if declarator.type == "function_declarator":
|
||||
inner_declarator = declarator.child_by_field_name("declarator")
|
||||
if not inner_declarator:
|
||||
return "<unknown>"
|
||||
if inner_declarator.type == "identifier":
|
||||
return inner_declarator.text.decode('utf-8')
|
||||
elif inner_declarator.type == "field_expression":
|
||||
return extract_qualified_name_from_field(inner_declarator)
|
||||
else:
|
||||
# 兜底处理(如数组、指针等,但函数名通常不会这样)
|
||||
raw = inner_declarator.text.decode('utf-8')
|
||||
return raw.split('(')[0].strip()
|
||||
|
||||
# 情况2: declarator 直接是 identifier(旧式C或简单情况)
|
||||
elif declarator.type == "identifier":
|
||||
return declarator.text.decode('utf-8')
|
||||
|
||||
# 情况3: 其他类型(如 pointer_declarator),尝试兜底
|
||||
else:
|
||||
raw = declarator.text.decode('utf-8')
|
||||
return raw.split('(')[0].strip()
|
||||
def traverse_type(node, target_type):
|
||||
"""递归遍历 AST 节点,查找指定类型"""
|
||||
if node.type == target_type:
|
||||
yield node
|
||||
for child in node.children:
|
||||
yield from traverse_type(child, target_type)
|
||||
# === 主处理流程(仅优化关联关系处理)===
|
||||
def build_rag_database():
|
||||
"""严格保留Tree-sitter/Qwen调用流程,仅优化关联关系处理"""
|
||||
global FUNCTION_CALL_GRAPH, FILE_DEPENDENCIES, CALLED_BY_GRAPH
|
||||
|
||||
# 清空旧图(避免多次调用累积)
|
||||
FUNCTION_CALL_GRAPH.clear()
|
||||
FILE_DEPENDENCIES.clear()
|
||||
CALLED_BY_GRAPH.clear()
|
||||
|
||||
dimension = 1024
|
||||
index = faiss.IndexFlatL2(dimension)
|
||||
metadata_list = []
|
||||
|
||||
import re
|
||||
|
||||
# 全局图(确保在阶段1前清空)
|
||||
FUNCTION_CALL_GRAPH = {}
|
||||
CALLED_BY_GRAPH = {}
|
||||
FILE_DEPENDENCIES = {}
|
||||
|
||||
def extract_callee_name(func_node):
|
||||
"""从 call_expression 的 function 节点中提取函数名标识符"""
|
||||
if func_node.type == "identifier":
|
||||
return func_node.text.decode('utf-8')
|
||||
elif func_node.type == "field_expression": # 如 A::func 或 a.b
|
||||
field = func_node.child_by_field_name("field")
|
||||
if field and field.type == "identifier":
|
||||
return field.text.decode('utf-8')
|
||||
# 递归处理复杂命名空间
|
||||
text = func_node.text.decode('utf-8')
|
||||
parts = re.split(r'[.:]+', text)
|
||||
return parts[-1].split('(')[0].strip() if parts else "<unknown>"
|
||||
elif func_node.type == "member_expression": # C++ 风格: obj->method 或 obj.method
|
||||
member = func_node.child_by_field_name("member")
|
||||
if member and member.type == "identifier":
|
||||
return member.text.decode('utf-8')
|
||||
# 兜底
|
||||
text = func_node.text.decode('utf-8')
|
||||
for sep in ['->', '.']:
|
||||
if sep in text:
|
||||
return text.split(sep)[-1].split('(')[0].strip()
|
||||
return text.split('(')[0].strip()
|
||||
else:
|
||||
# 兜底:从文本中提取最可能的函数名
|
||||
text = func_node.text.decode('utf-8')
|
||||
# 移除模板 <...>
|
||||
text = re.sub(r'<[^>]*>', '', text)
|
||||
# 取括号前部分
|
||||
base = text.split('(')[0].strip()
|
||||
# 处理 a::b::c -> c
|
||||
for sep in ['::', '->', '.']:
|
||||
if sep in base:
|
||||
base = base.split(sep)[-1]
|
||||
# 简单合法性检查
|
||||
if base and base[0].isalpha() and all(c.isalnum() or c in '_$' for c in base):
|
||||
return base
|
||||
return "<unknown>"
|
||||
|
||||
def collect_call_expressions(node):
|
||||
"""递归遍历 AST 节点,收集所有 call_expression 中的 callee 名称(去重)"""
|
||||
calls = set()
|
||||
|
||||
def traverse(n):
|
||||
if n.type == "call_expression":
|
||||
func_node = n.child_by_field_name("function")
|
||||
if func_node:
|
||||
callee = extract_callee_name(func_node)
|
||||
if callee != "<unknown>":
|
||||
calls.add(callee)
|
||||
# 递归子节点
|
||||
for child in n.children:
|
||||
traverse(child)
|
||||
|
||||
traverse(node)
|
||||
return list(calls)
|
||||
|
||||
# ========================
|
||||
# 阶段1:构建调用图与依赖图
|
||||
# ========================
|
||||
logger.info("阶段1: 构建函数调用图和文件依赖图")
|
||||
|
||||
total_funcs = 0
|
||||
|
||||
for root, _, files in os.walk(PROJECT_ROOT):
|
||||
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
|
||||
continue
|
||||
for file in files:
|
||||
if Path(file).suffix.lower() in CPP_EXTENSIONS:
|
||||
file_path = os.path.join(root, file)
|
||||
logger.debug(f"解析文件构建关联图: {file_path}")
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
source_bytes = f.read()
|
||||
tree = parser.parse(source_bytes)
|
||||
|
||||
# --- 1. 收集头文件依赖 ---
|
||||
includes = []
|
||||
for node in tree.root_node.named_children:
|
||||
if node.type == "preproc_include":
|
||||
path_node = node.child_by_field_name("path")
|
||||
if path_node:
|
||||
include_text = path_node.text.decode('utf-8').strip('"<>')
|
||||
includes.append(include_text)
|
||||
FILE_DEPENDENCIES[file_path] = includes
|
||||
|
||||
# --- 2. 遍历顶层节点,找函数定义 ---
|
||||
for node in tree.root_node.named_children:
|
||||
if node.type == "function_definition":
|
||||
func_name = extract_full_function_name(node)
|
||||
if func_name == "<unknown>":
|
||||
continue
|
||||
|
||||
total_funcs += 1
|
||||
FUNCTION_CALL_GRAPH[func_name] = [] # 初始化
|
||||
|
||||
# 获取函数体(body)
|
||||
body = node.child_by_field_name("body")
|
||||
if body:
|
||||
callees = collect_call_expressions(body)
|
||||
FUNCTION_CALL_GRAPH[func_name] = callees
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析文件失败 {file_path}: {e}", exc_info=True)
|
||||
|
||||
# --- 3. 构建反向调用图(谁调用了我)---
|
||||
CALLED_BY_GRAPH.clear()
|
||||
for caller, callees in FUNCTION_CALL_GRAPH.items():
|
||||
for callee in callees:
|
||||
if callee not in CALLED_BY_GRAPH:
|
||||
CALLED_BY_GRAPH[callee] = []
|
||||
CALLED_BY_GRAPH[callee].append(caller)
|
||||
|
||||
logger.info(f"阶段1完成。共识别 {total_funcs} 个函数,{len(FILE_DEPENDENCIES)} 个文件依赖。")
|
||||
# === 阶段2: 生成摘要并构建RAG ===
|
||||
logger.info("阶段2: 生成函数摘要并构建RAG")
|
||||
for root, _, files in os.walk(PROJECT_ROOT):
|
||||
if any(ignore_dir in Path(root).parts for ignore_dir in IGNORE_DIRS):
|
||||
continue
|
||||
for file in files:
|
||||
if Path(file).suffix.lower() in CPP_EXTENSIONS:
|
||||
file_path = os.path.join(root, file)
|
||||
logger.debug(f"处理文件: {file_path}")
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
source = f.read()
|
||||
tree = parser.parse(source)
|
||||
source_lines = source.split(b'\n')
|
||||
|
||||
for node in tree.root_node.named_children:
|
||||
if node.type == "function_definition":
|
||||
# === 关键修正:使用统一的函数名提取逻辑 ===
|
||||
func_name = extract_full_function_name(node)
|
||||
if func_name == "<unknown>":
|
||||
logger.warning(f"无法解析函数名,跳过节点: {node.text[:100]}...")
|
||||
continue
|
||||
|
||||
# 提取前置注释(保持不变)
|
||||
comment = extract_comment_before(node, source_lines)
|
||||
|
||||
# 提取代码片段(保持不变)
|
||||
func_code = extract_code_snippet(tree, node, source_lines)
|
||||
print(func_code)
|
||||
# 生成摘要(保持不变)
|
||||
logic = generate_code_logic(func_name, comment, func_code, file_path)
|
||||
summary = generate_code_summary(func_name, comment,logic,func_code, file_path)
|
||||
print(summary)
|
||||
# 从全局图中获取关系(确保 func_name 与阶段1一致!)
|
||||
called_functions = FUNCTION_CALL_GRAPH.get(func_name, [])
|
||||
caller_functions = CALLED_BY_GRAPH.get(func_name, [])
|
||||
included_headers = FILE_DEPENDENCIES.get(file_path, [])
|
||||
|
||||
# 构建上下文文本
|
||||
context_text = (
|
||||
f"【实体类型】函数\n"
|
||||
f"【函数名】{func_name}\n"
|
||||
f"【所在文件】{os.path.basename(file_path)}\n"
|
||||
f"【代码注释】{comment.strip() if comment else '无'}\n"
|
||||
f"【功能摘要】{summary}\n"
|
||||
f"【流程逻辑】{logic}\n"
|
||||
f"【调用的函数】{', '.join(called_functions) if called_functions else '无'}\n"
|
||||
f"【被以下函数调用】{', '.join(caller_functions) if caller_functions else '无'}\n"
|
||||
f"【包含的头文件】{', '.join(included_headers) if included_headers else '无'}"
|
||||
)
|
||||
|
||||
# 生成嵌入向量
|
||||
vector = get_qwen_embedding(context_text)
|
||||
|
||||
# 构建元数据
|
||||
metadata = {
|
||||
"file": file_path,
|
||||
"line": node.start_point[0] + 1,
|
||||
"type": "function",
|
||||
"name": func_name,
|
||||
"summary": summary,
|
||||
"logic": logic,
|
||||
"calls": called_functions,
|
||||
"called_by": caller_functions,
|
||||
"includes": included_headers,
|
||||
"comment": comment or ""
|
||||
}
|
||||
|
||||
# 添加到向量库
|
||||
index.add(np.array([vector], dtype='float32'))
|
||||
metadata_list.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True)
|
||||
# 保存结果
|
||||
faiss.write_index(index, VECTOR_DB_PATH)
|
||||
with open(METADATA_PATH, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata_list, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"知识库构建完成! 总实体: {len(metadata_list)}")
|
||||
logger.info(f"向量数据库保存至: {VECTOR_DB_PATH}")
|
||||
logger.info(f"元数据保存至: {METADATA_PATH}")
|
||||
|
||||
return index, metadata_list # 如果你确实需要返回,可以放这里
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
build_rag_database()
|
||||
Reference in New Issue
Block a user