Files
rag_agent/RAG-TEST-TOOLS/feature_retriever.py

611 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 在 FeatureRetriever 类内部新增以下常量和方法
import json
import os
import re
from typing import Dict, List, Optional
import logging
import faiss
import numpy as np
from config import TOP_K, METADATA_PATH, VECTOR_DB_PATH, KNOWLEDGE_GRAPH_PATH, MIN_SIMILARITY_THRESHOLD
from llm_processor import call_qwen_max, get_qwen_embedding
logger = logging.getLogger(__name__)
class KnowledgeGraphAnalyzer:
"""知识图谱分析器(简化版)"""
def __init__(self, graph_data: Dict):
"""初始化图谱分析器"""
self.nodes = {node["id"]: node for node in graph_data.get("nodes", [])}
self.edges = graph_data.get("edges", [])
# 构建邻接表
self.calls_graph = self._build_calls_graph()
self.reverse_calls_graph = self._build_reverse_calls_graph()
# 构建名称到ID的映射
self.name_to_id = self._build_name_mapping()
def _build_calls_graph(self) -> Dict[str, List[str]]:
"""构建函数调用图"""
calls_graph = {}
for edge in self.edges:
if edge["type"] == "CALLS":
source_id = edge["source_id"]
target_id = edge["target_id"]
if source_id not in calls_graph:
calls_graph[source_id] = []
calls_graph[source_id].append(target_id)
return calls_graph
def _build_reverse_calls_graph(self) -> Dict[str, List[str]]:
"""构建反向调用图"""
reverse_graph = {}
for edge in self.edges:
if edge["type"] == "CALLS":
source_id = edge["source_id"]
target_id = edge["target_id"]
if target_id not in reverse_graph:
reverse_graph[target_id] = []
reverse_graph[target_id].append(source_id)
return reverse_graph
def _build_name_mapping(self) -> Dict[str, str]:
"""构建函数名到节点ID的映射"""
name_to_id = {}
for node_id, node in self.nodes.items():
if node["type"] == "Function":
name_to_id[node["name"]] = node_id
return name_to_id
def find_function_by_name(self, function_name: str) -> Optional[Dict]:
"""根据函数名查找节点"""
node_id = self.name_to_id.get(function_name)
if node_id:
return self.nodes.get(node_id)
return None
def analyze_function_context(self, function_id: str) -> Dict:
"""分析函数的上下文(调用者和被调用者)"""
if function_id not in self.nodes:
return {"error": f"未找到函数节点: {function_id}"}
result = {
"function_info": self.nodes[function_id],
"callers": [],
"callees": [],
"call_depth": 0
}
# 分析调用者
if function_id in self.reverse_calls_graph:
for caller_id in self.reverse_calls_graph[function_id]:
if caller_id in self.nodes:
result["callers"].append(self.nodes[caller_id])
# 分析被调用者
if function_id in self.calls_graph:
for callee_id in self.calls_graph[function_id]:
if callee_id in self.nodes:
result["callees"].append(self.nodes[callee_id])
# 计算调用深度
result["call_depth"] = self._calculate_max_depth(function_id)
return result
def _calculate_max_depth(self, func_id: str, visited: set = None) -> int:
"""计算函数的最大调用深度"""
if visited is None:
visited = set()
if func_id in visited or func_id not in self.calls_graph:
return 0
visited.add(func_id)
max_depth = 0
for callee_id in self.calls_graph.get(func_id, []):
depth = 1 + self._calculate_max_depth(callee_id, visited.copy())
max_depth = max(max_depth, depth)
return max_depth
class FeatureRetriever:
"""功能需求检索器主类(简化版)"""
def __init__(self, config: Optional[Dict] = None):
"""初始化检索器"""
self.llm_judgment_prompt_template = self._get_judgment_prompt_template()
def _get_judgment_prompt_template(self) -> str:
"""返回用于LLM推理的提示词模板"""
return r"""
你是一个资深的软件架构师和代码评审专家。你需要根据用户提出的功能需求,结合给定的代码知识库分析结果,判断该需求在项目中是否已被实现。
# 任务
请严格遵循以下步骤和约束条件进行分析,并给出最终裁决。
# 输入信息
1. **用户需求**{user_query}
2. **候选函数分析档案**
{analysis_dossiers}
# 你必须遵循的推理约束(知识图谱约束)
请基于以下四个约束条件,对每个候选函数进行分析,并综合判断需求实现状态:
**约束1功能完整性约束**
- 检查候选函数的调用链callers & callees是否能完整实现用户需求所隐含的所有子步骤或业务流程。
- 关键问题:是否存在逻辑断点?调用链是否覆盖了需求中“输入-处理-输出”的全过程?
**约束2架构/设计模式契合度约束**
- 识别候选函数所体现的设计模式(如工厂、策略、观察者、责任链、模板方法等)。
- 判断该模式是否与用户需求所描述的系统行为模式、扩展性要求或解耦意图相匹配。
**约束3模块协调性约束**
- 分析候选函数与其相关函数(调用者/被调用者)是否共同构成了一个高内聚、低耦合的功能模块。
- 关键问题:相关函数是否在语义上属于同一业务领域?它们之间的数据流是否清晰合理?
**约束4语义覆盖深度约束**
- 超越简单的文本相似度,深入分析函数在代码上下文中的**实际语义**。
- 关键问题:函数的名称、摘要、调用关系所体现的核心功能,是否精准地**覆盖**了用户需求的**核心意图**?是否存在语义偏差或功能泛化/缩窄?
你的输出**必须且只能**是一个完全符合 RFC 8259 标准的 JSON 对象。该 JSON 必须能被 Python 的 `json.loads()` 函数**直接解析**,无需任何预处理。
为了确保这一点,你必须严格遵守以下规则:
1. **禁用反引号**:在整个 JSON 对象的任何字符串值中,**绝对不允许**使用反引号 (\`)。所有需要引用标识符(如函数名、文件名)的地方,应直接将其作为字符串的一部分,**无需用任何特殊字符包围**。例如,应写成 `StarCheckGyro`,而非 \`StarCheckGyro\`。
2. **正确转义双引号**:如果在字符串内部**必须**使用双引号("),则必须使用反斜杠进行转义,写作 `\"`。例如,描述“主函数”的字符串应写作 `"主函数"`,若需在内部提及带引号的文本,应写作 `"他说:\\"你好\\""`。
3. **正确的布尔值和空值**:布尔值必须使用小写的 `true` 或 `false`。空值必须使用小写的 `null`。禁止使用 Python 的 `True`, `False`, `None` 或大写的 `Null`。
4. **无注释**JSON 对象中不允许包含任何形式的注释(`//`, `/* */`)。
5. **无尾随逗号**:在对象或数组的最后一个元素后,不允许有逗号。
违反以上任何一条规则,都将导致你的输出无法被解析,任务失败。
# 输出格式
你必须以以下JSON格式输出且仅输出此JSON对象
{{
"verdict": "完整实现 | 部分实现 | 未实现",
"confidence": 0.0到1.0之间的一个浮点数,
"primary_reasoning": "一段文字综合阐述依据约束1-4得出的主要判断理由。",
"constraint_analysis": {{
"functional_completeness": {{
"status": "满足 | 部分满足 | 不满足",
"reason": "基于约束1的分析说明。"
}},
"architectural_fit": {{
"status": "匹配 | 部分匹配 | 不匹配",
"reason": "基于约束2的分析说明。"
}},
"module_coordination": {{
"status": "协调 | 部分协调 | 不协调",
"reason": "基于约束3的分析说明。"
}},
"semantic_coverage": {{
"status": "覆盖 | 部分覆盖 | 未覆盖",
"reason": "基于约束4的分析说明。"
}}
}},
"call_chain_analysis": {{
"description": "对实现需求的核心调用链的详细文字描述。包括起点函数、关键处理函数、终点函数,以及数据流或控制流的简要说明。",
"chain": ["函数名1", "函数名2", "...", "函数名N"] # 按调用顺序列出关键函数名
}},
"design_pattern_analysis": {{
"identified_patterns": ["模式1", "模式2", ...], # 识别出的设计模式列表
"explanation": "对识别出的设计模式进行解释,说明其在代码中如何体现,以及为何适用于当前需求。"
}},
"most_relevant_functions": [
{{
"name": "函数名",
"role": "此函数在满足需求中扮演的角色(如:主入口、关键计算、数据提供者、校验器、协调者等)",
"supporting_evidence": "来自分析档案的简要证据"
}}
],
"gap_analysis": "如果 verdict 不是 '完整实现',请详细说明缺失的功能组件、断裂的调用链或语义偏差。否则为空字符串。"
}}
现在,请开始你的分析。
"""
def _generate_analysis_dossier(self, func_info: Dict) -> str:
"""为单个候选函数生成详细的文本分析档案"""
dossier_lines = []
# 1. 基础信息
dossier_lines.append(f"## 函数: {func_info.get('name', 'N/A')}")
dossier_lines.append(f"- **文件**: {os.path.basename(func_info.get('file', 'N/A'))}")
dossier_lines.append(f"- **语义相似度**: {func_info.get('similarity', 0):.3f}")
dossier_lines.append(f"- **功能摘要**: {func_info.get('summary', '')}")
# 2. 调用链深度分析 (利用知识图谱)
func_name = func_info.get('name')
call_chain_info = " (知识图谱未加载,无法分析)"
if self.kg_analyzer and func_name:
node = self.kg_analyzer.find_function_by_name(func_name)
if node:
ctx = self.kg_analyzer.analyze_function_context(node['id'])
dossier_lines.append(f"- **调用深度**: {ctx.get('call_depth', 0)}")
dossier_lines.append(f"- **被以下函数调用 (Callers, {len(ctx.get('callers', []))}个)**:")
for caller in ctx.get('callers', [])[:3]: # 最多显示3个
dossier_lines.append(f" - {caller.get('name', 'N/A')}")
dossier_lines.append(f"- **调用以下函数 (Callees, {len(ctx.get('callees', []))}个)**:")
for callee in ctx.get('callees', [])[:5]: # 最多显示5个
dossier_lines.append(f" - {callee.get('name', 'N/A')}")
call_chain_info = f" 深度{ctx.get('call_depth', 0)} {len(ctx.get('callers', []))}个调用者, {len(ctx.get('callees', []))}个被调用者。"
# 3. 架构模式分析 (增强版)
pattern_info = self._analyze_design_pattern(func_info)
dossier_lines.append(f"- **识别的设计模式/架构特征**: {', '.join(pattern_info) if pattern_info else '未识别到明显模式'}")
dossier_lines.append("") # 空行分隔
return "\n".join(dossier_lines)
def _analyze_design_pattern(self, func_info: Dict) -> List[str]:
"""增强的设计模式分析,返回识别出的模式列表"""
patterns = []
func_name = func_info.get('name', '')
if not self.kg_analyzer:
return patterns
# 关键修正:使用您原有的 find_function_by_name 方法
node = self.kg_analyzer.find_function_by_name(func_name)
if not node:
return patterns
# 关键修正:使用您原有的 analyze_function_context 方法获取调用关系
ctx = self.kg_analyzer.analyze_function_context(node['id'])
callers = ctx.get('callers', [])
callees = ctx.get('callees', [])
callers_cnt = len(callers)
callees_cnt = len(callees)
# 基于调用关系推断模式
if callers_cnt == 0 and callees_cnt > 2:
patterns.append("“工厂方法”或“构建器” (可能是对象创建入口)")
elif callers_cnt > 2 and callees_cnt == 0:
patterns.append("“观察者”的通知方法 或 “策略”的接口实现")
elif callers_cnt == 1 and callees_cnt > 1:
patterns.append("“模板方法”中的步骤定义 或 “外观”模式接口")
elif callers_cnt > 1 and callees_cnt > 1:
patterns.append("“协调者”或“中介者” (复杂逻辑协调)")
if ctx.get('call_depth', 0) >= 3:
patterns.append("“责任链”或“管道” (多级处理)")
# 基于函数名关键词的启发式规则 (可选)
name_lower = func_name.lower()
if 'factory' in name_lower or 'create' in name_lower or 'build' in name_lower:
patterns.append("函数名暗示“创建型”模式")
if 'handler' in name_lower or 'processor' in name_lower or 'service' in name_lower:
patterns.append("函数名暗示“行为型”模式")
if 'adapter' in name_lower or 'wrapper' in name_lower:
patterns.append("函数名暗示“结构型”模式")
return list(set(patterns)) # 去重
def analyze_with_multiple_constraints(self, query: str) -> Dict:
"""
基于多推理约束分析需求实现状态 (新版)
流程1. 语义检索 -> 2. 为每个结果生成分析档案 -> 3. LLM基于约束推理
"""
if not self.faiss_index:
logger.error("知识库未加载")
return {"error": "知识库未完全加载"}
# 阶段1: 语义检索 (不变)
relevant_functions = self.search_by_semantic(query, top_k=TOP_K)
if not relevant_functions:
return {
"implemented": False,
"reason": "在代码库中未找到语义相似度足够高的相关函数",
"confidence": 0.1,
"analysis_method": "semantic_search_only",
"details": "语义检索未命中。"
}
# 阶段2: 为每个相关函数生成深度分析档案
analysis_dossiers = []
for func_info in relevant_functions:
dossier_text = self._generate_analysis_dossier(func_info)
analysis_dossiers.append(dossier_text)
combined_dossiers = "\n---\n".join(analysis_dossiers)
# 阶段3: 调用LLM基于约束进行综合推理判断
llm_judgment_result = self._make_llm_judgment(query, combined_dossiers)
# 阶段4: 格式化最终输出,兼容新旧接口
final_result = self._format_final_result(llm_judgment_result, relevant_functions)
return final_result
def _make_llm_judgment(self, query: str, analysis_dossiers: str) -> Dict:
"""核心调用LLM强制其在知识图谱约束下进行推理评判"""
try:
# 构建提示词
prompt = self.llm_judgment_prompt_template.format(
user_query=query,
analysis_dossiers=analysis_dossiers
)
# 调用LLM
llm_response_text = call_qwen_max(prompt, temperature=0.1)
# --- 新增健壮的JSON解析与修复 ---
import json
judgment_data = self._robust_json_parse(llm_response_text)
if judgment_data is not None:
return judgment_data
else:
# 如果修复后解析仍然失败,则回退
logger.error(f"LLM响应JSON修复后解析仍失败。原始响应:\n{llm_response_text}")
return self._get_fallback_judgment()
except Exception as e:
logger.error(f"LLM推理过程出错: {e}")
return self._get_fallback_judgment()
def _robust_json_parse(self, raw_text: str) -> Optional[Dict]:
"""
尝试从原始文本中稳健地解析JSON。
步骤1. 提取JSON块。 2. 清理常见问题。 3. 尝试解析。 4. 如失败,尝试修复常见问题后重试。
"""
import json
import re
# 步骤1: 提取最可能包含JSON的部分
# 去除可能的Markdown代码块标记
text = raw_text.strip()
if text.startswith('```'):
# 移除开头的 ```json 或 ```
lines = text.split('\n', 1)
if len(lines) > 1:
text = lines[1].strip()
if text.endswith('```'):
text = text[:-3].strip()
# 查找第一个 '{' 和最后一个 '}' 的位置
start = text.find('{')
end = text.rfind('}')
if start == -1 or end == -1 or start >= end:
logger.warning(f"在LLM响应中未找到有效的JSON结构。")
return None
json_str = text[start:end + 1] # 提取从 { 到 } 的字符串
# 步骤2: 首次尝试直接解析
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
logger.warning(f"首次JSON解析失败 (位置 {e.pos}),尝试清理和修复。错误: {e.msg}")
# 记录出错位置附近的文本,便于调试
snippet_start = max(0, e.pos - 20)
snippet_end = min(len(json_str), e.pos + 20)
logger.debug(f"错误位置附近文本: ...{json_str[snippet_start:snippet_end]}...")
# 步骤3: 进行一系列文本清理和修复尝试
repaired_json = json_str
pattern = r'(?<!\\)"(\w+)"'
def replace_inner_quotes(match):
# 将匹配到的 "word" 替换为 \"word\"
return r'\"' + match.group(1) + r'\"'
repaired_json = re.sub(pattern, replace_inner_quotes, repaired_json)
# 尝试2: 修复不匹配的引号 (例如,字符串以单引号开始但以双引号结束)
# 这里我们只是简单地移除字符串开头/结尾可能存在的杂散引号或换行符
repaired_json = repaired_json.strip()
# 尝试3: 修复不规范的布尔值或null (例如 True, False, Null)
repaired_json = re.sub(r':\s*True\b', ': true', repaired_json)
repaired_json = re.sub(r':\s*False\b', ': false', repaired_json)
repaired_json = re.sub(r':\s*Null\b', ': null', repaired_json, flags=re.IGNORECASE)
repaired_json = re.sub(r':\s*None\b', ': null', repaired_json)
# 步骤4: 尝试解析修复后的JSON
try:
return json.loads(repaired_json)
except json.JSONDecodeError as e2:
logger.error(f"修复后JSON解析仍然失败 (位置 {e2.pos})。修复后文本片段:\n...{repaired_json[max(0, e2.pos-50):e2.pos+50]}...")
return None
def _get_fallback_judgment(self) -> Dict:
"""LLM调用失败时的回退方案"""
return {
"verdict": "未知",
"confidence": 0.0,
"primary_reasoning": "LLM推理过程发生错误使用回退逻辑。",
"constraint_analysis": {
"functional_completeness": {"status": "未知", "reason": "错误"},
"architectural_fit": {"status": "未知", "reason": "错误"},
"module_coordination": {"status": "未知", "reason": "错误"},
"semantic_coverage": {"status": "未知", "reason": "错误"}
},
"most_relevant_functions": [],
"gap_analysis": "系统内部错误。"
}
def _format_final_result(self, llm_judgment: Dict, relevant_functions: List[Dict]) -> Dict:
"""将LLM的评判结果格式化为兼容原接口的最终输出"""
# 映射 verdict 到 implemented
verdict_map = {
"完整实现": True,
"部分实现": True, # 部分实现也视为某种程度的实现
"未实现": False,
"未知": False
}
implemented = verdict_map.get(llm_judgment.get("verdict", "未知"), False)
# 构建详细的约束分析结果,兼容旧版格式的同时提供更丰富的信息
constraint_scores = {}
constraint_status = llm_judgment.get("constraint_analysis", {})
# 为每个约束生成一个模拟分数(基于状态),用于兼容性
for cons_name, cons_info in constraint_status.items():
status_score_map = {"满足": 1.0, "匹配": 1.0, "协调": 1.0, "覆盖": 1.0,
"部分满足": 0.5, "部分匹配": 0.5, "部分协调": 0.5, "部分覆盖": 0.5,
"不满足": 0.0, "不匹配": 0.0, "不协调": 0.0, "未覆盖": 0.0,
"未知": 0.0}
score = status_score_map.get(cons_info.get("status", "未知"), 0.0)
constraint_scores[cons_name] = {
"score": score,
"passed": score >= 0.5, # 0.5作为通过的阈值
"details": cons_info.get("reason", "")
}
# 选取相似度最高的函数作为 most_relevant_function
most_relevant = None
if relevant_functions:
best_func = max(relevant_functions, key=lambda x: x["similarity"])
most_relevant = {
"name": best_func["name"],
"file": os.path.basename(best_func.get("file", "未知文件")),
"similarity": best_func["similarity"],
"summary": best_func.get("summary", "无摘要")[:100]
}
# 构建最终结果
final_result = {
"implemented": implemented,
"reason": llm_judgment.get("primary_reasoning", "无理由提供"),
"confidence": llm_judgment.get("confidence", 0.0),
"total_score": llm_judgment.get("confidence", 0.0), # 用confidence作为total_score
"passed_constraints": sum(1 for cs in constraint_scores.values() if cs.get("passed", False)),
"total_constraints": len(constraint_scores),
"constraint_scores": constraint_scores,
"most_relevant_function": most_relevant,
"relevant_functions_count": len(relevant_functions),
# 新增字段提供更丰富的LLM推理结果
"llm_judgment": {
"verdict": llm_judgment.get("verdict", "未知"),
"gap_analysis": llm_judgment.get("gap_analysis", ""),
"most_relevant_functions_roles": llm_judgment.get("most_relevant_functions", [])
},
"analysis_method": "llm_constraint_reasoning" # 标识使用了新方法
}
return final_result
def __init__(self, config: Optional[Dict] = None):
"""初始化检索器"""
self.config = config or {}
self.faiss_index = None
self.contexts = None
self.metadatas = None
self.knowledge_graph = None
self.kg_analyzer = None
# 从配置加载路径
self.vector_db_path = config.get("vector_db_path", VECTOR_DB_PATH)
self.metadata_path = config.get("metadata_path", METADATA_PATH)
self.knowledge_graph_path = config.get("knowledge_graph_path", KNOWLEDGE_GRAPH_PATH)
# 新增:初始化提示词模板
self.llm_judgment_prompt_template = self._get_judgment_prompt_template()
def load_knowledge_base(self) -> bool:
"""加载知识库文件"""
try:
logger.info("正在加载知识库...")
# 1. 加载FAISS索引
if not os.path.exists(self.vector_db_path):
logger.error(f"向量数据库文件不存在: {self.vector_db_path}")
return False
self.faiss_index = faiss.read_index(self.vector_db_path)
# 2. 加载元数据
if not os.path.exists(self.metadata_path):
logger.error(f"元数据文件不存在: {self.metadata_path}")
return False
with open(self.metadata_path, "r", encoding="utf-8") as f:
self.metadatas = json.load(f)
# 3. 加载知识图谱
if os.path.exists(self.knowledge_graph_path):
with open(self.knowledge_graph_path, "r", encoding="utf-8") as f:
self.knowledge_graph = json.load(f)
# 初始化知识图谱分析器
self.kg_analyzer = KnowledgeGraphAnalyzer(self.knowledge_graph)
else:
logger.warning(f"知识图谱文件不存在: {self.knowledge_graph_path}")
self.knowledge_graph = None
self.kg_analyzer = None
# 4. 重建上下文文本
self.contexts = []
for meta in self.metadatas:
func_name = meta.get("name", "未知函数")
file_path = meta.get("file", "未知文件")
summary = meta.get("summary", "无摘要")
context_str = (
f"【函数名】{func_name}\n"
f"【所在文件】{os.path.basename(file_path)}\n"
f"【功能摘要】{summary}"
)
self.contexts.append(context_str)
logger.info(f"知识库加载成功: {len(self.metadatas)} 个函数")
return True
except Exception as e:
logger.error(f"加载知识库失败: {e}")
return False
def search_by_semantic(self, query: str, top_k: int = TOP_K) -> List[Dict]:
"""语义搜索:基于向量相似度查找相关函数"""
if not self.faiss_index or not self.contexts:
logger.error("知识库未加载")
return []
try:
# 生成查询向量
query_vec = get_qwen_embedding(query)
query_vec = np.expand_dims(query_vec, axis=0).astype(np.float32)
# 搜索最近邻
distances, indices = self.faiss_index.search(query_vec, top_k)
# 处理结果
relevant_functions = []
for i, idx in enumerate(indices[0]):
similarity = 1 - 0.5 * float(distances[0][i])
if similarity >= MIN_SIMILARITY_THRESHOLD:
func_info = {
"index": int(idx),
"similarity": similarity,
"context": self.contexts[idx]
}
# 提取元数据
func_info.update(self._get_function_metadata(idx))
relevant_functions.append(func_info)
return relevant_functions
except Exception as e:
logger.error(f"语义搜索失败: {e}")
return []
def _get_function_metadata(self, index: int) -> Dict:
"""获取函数的元数据"""
if index < 0 or index >= len(self.metadatas):
return {}
meta = self.metadatas[index]
func_name = meta.get("name", "未知函数")
file_path = meta.get("file", "未知文件")
return {
"name": func_name,
"file": file_path,
"summary": meta.get("summary", "无摘要"),
"calls": meta.get("calls", []),
"called_by": meta.get("called_by", [])
}
def _extract_function_name(self, context: str) -> str:
"""从上下文文本中提取函数名"""
match = re.search(r"【函数名】(.*)", context)
if match:
return match.group(1).strip()
return "未知函数"