259 lines
10 KiB
Python
259 lines
10 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import re
|
||
|
|
from typing import Any, Dict, Iterable, List, Optional
|
||
|
|
|
||
|
|
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
|
||
|
|
from app.services.code_kb.formatter import format_evidence_context
|
||
|
|
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
|
||
|
|
from app.services.consistency.prompt import build_judgment_prompt, build_requirement_query
|
||
|
|
from app.services.consistency.schema import ConsistencyResultItem, RequirementSnapshot, VERDICTS
|
||
|
|
from app.services.consistency.scorer import coverage_score
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def _clip(value: str, limit: int) -> str:
|
||
|
|
text = value or ""
|
||
|
|
if len(text) <= limit:
|
||
|
|
return text
|
||
|
|
return text[:limit].rstrip() + "\n...[truncated]"
|
||
|
|
|
||
|
|
|
||
|
|
def _as_list(value: Any) -> List[str]:
|
||
|
|
if value is None:
|
||
|
|
return []
|
||
|
|
if isinstance(value, list):
|
||
|
|
return [str(item) for item in value if str(item).strip()]
|
||
|
|
if isinstance(value, tuple):
|
||
|
|
return [str(item) for item in value if str(item).strip()]
|
||
|
|
if isinstance(value, str):
|
||
|
|
text = value.strip()
|
||
|
|
if not text:
|
||
|
|
return []
|
||
|
|
try:
|
||
|
|
parsed = json.loads(text)
|
||
|
|
return _as_list(parsed)
|
||
|
|
except json.JSONDecodeError:
|
||
|
|
return [line.strip() for line in text.splitlines() if line.strip()]
|
||
|
|
return [str(value)]
|
||
|
|
|
||
|
|
|
||
|
|
def requirement_to_snapshot(requirement: Any) -> RequirementSnapshot:
    """Build a ``RequirementSnapshot`` from a dict or an attribute-style object.

    Fields are read with ``dict.get`` when *requirement* is a dict and with
    ``getattr(..., None)`` otherwise.  Most fields accept a snake_case name
    with a camelCase alias; the first truthy value wins, and when none is
    truthy the value read for the last alias is passed through unchanged.

    Args:
        requirement: Mapping or object carrying the requirement fields.

    Returns:
        The populated ``RequirementSnapshot``.
    """
    if isinstance(requirement, dict):
        def read(key: str) -> Any:
            return requirement.get(key)
    else:
        def read(key: str) -> Any:
            return getattr(requirement, key, None)

    def first_truthy(*keys: str) -> Any:
        # Mirrors an ``a or b`` chain: stop at the first truthy value,
        # otherwise hand back whatever the final key produced.
        found = None
        for key in keys:
            found = read(key)
            if found:
                break
        return found

    return RequirementSnapshot(
        requirement_uid=first_truthy("requirement_uid", "id") or "",
        title=first_truthy("title") or "",
        description=first_truthy("description") or "",
        acceptance_criteria=_as_list(
            first_truthy("acceptance_criteria", "acceptanceCriteria")
        ),
        requirement_type=first_truthy("requirement_type", "requirementType"),
        section_title=first_truthy("section_title", "sectionTitle"),
        interface_name=first_truthy("interface_name", "interfaceName"),
        interface_type=first_truthy("interface_type", "interfaceType"),
        data_source=first_truthy("data_source", "dataSource"),
        data_destination=first_truthy("data_destination", "dataDestination"),
    )
|
||
|
|
|
||
|
|
|
||
|
|
class ConsistencyComparator:
    """Compare requirements against function evidence from a code knowledge base.

    For every requirement the comparator searches the knowledge base for
    candidate functions, expands each hit's call context, and derives a
    judgment: a fixed "missing" judgment when nothing was found, a
    similarity-based heuristic when the LLM is disabled, or an LLM judgment
    otherwise.
    """

    def __init__(
        self,
        code_kb_adapter: CodeKnowledgeBaseAdapter,
        llm: Any = None,
        use_llm: bool = True,
    ) -> None:
        """Store the collaborators.

        Args:
            code_kb_adapter: Adapter used for function search and call-context
                expansion.
            llm: Optional model; either exposes ``invoke(prompt)`` or is
                itself callable with the prompt.  When falsy, a model is
                created through ``LLMFactory`` on demand.
            use_llm: When ``False`` the heuristic judgment is used instead of
                the model.
        """
        self.code_kb_adapter = code_kb_adapter
        self.llm = llm
        self.use_llm = use_llm

    def compare_requirements(
        self,
        requirements: Iterable[Any],
        top_k: int = 8,
        max_call_hops: int = 2,
        min_similarity: float = 0.55,
    ) -> List[ConsistencyResultItem]:
        """Run :meth:`compare_requirement` on each requirement, in order.

        Args:
            requirements: Requirements as dicts or attribute-style objects.
            top_k: Forwarded to the function search.
            max_call_hops: Forwarded to the call-context expansion.
            min_similarity: Forwarded to the function search.

        Returns:
            One result item per input requirement.
        """
        return [
            self.compare_requirement(
                requirement,
                top_k=top_k,
                max_call_hops=max_call_hops,
                min_similarity=min_similarity,
            )
            for requirement in requirements
        ]

    def compare_requirement(
        self,
        requirement: Any,
        top_k: int = 8,
        max_call_hops: int = 2,
        min_similarity: float = 0.55,
    ) -> ConsistencyResultItem:
        """Evaluate a single requirement against the code knowledge base.

        Args:
            requirement: Requirement as a dict or attribute-style object.
            top_k: ``top_k`` passed to ``search_functions``.
            max_call_hops: ``max_hops`` passed to ``expand_call_context``.
            min_similarity: ``min_similarity`` passed to ``search_functions``.

        Returns:
            The result item holding verdict, score, evidence and the
            normalized raw judgment.
        """
        snapshot = requirement_to_snapshot(requirement)
        query = build_requirement_query(snapshot)
        hits = self.code_kb_adapter.search_functions(
            query=query,
            top_k=top_k,
            min_similarity=min_similarity,
        )
        contexts = [
            self.code_kb_adapter.expand_call_context(hit.evidence.node_id, max_hops=max_call_hops)
            for hit in hits
        ]

        if not hits:
            judgment = self._missing_judgment("未找到满足相似度阈值的函数证据。")
        elif not self.use_llm:
            judgment = self._heuristic_judgment(hits, contexts)
        else:
            judgment = self._llm_judgment(snapshot, hits, contexts)

        judgment = self._normalize_judgment(judgment)
        # The snapshot is attached before scoring, so the scorer receives it
        # as part of the judgment.
        judgment["requirement_snapshot"] = snapshot.to_dict()
        score = coverage_score(snapshot, hits, contexts, judgment)
        matched_functions = [self._matched_function_payload(hit) for hit in hits]
        call_chains = self._collect_call_chains(contexts)

        return ConsistencyResultItem(
            requirement_uid=snapshot.requirement_uid,
            requirement_title=snapshot.title,
            requirement_type=snapshot.requirement_type,
            requirement_text=snapshot.description,
            verdict=judgment["verdict"],
            coverage_score=score,
            confidence=float(judgment.get("confidence") or 0.0),
            matched_functions=matched_functions,
            covered_points=_as_list(judgment.get("covered_points")),
            missing_points=_as_list(judgment.get("missing_points")),
            conflict_points=_as_list(judgment.get("conflict_points")),
            call_chain_evidence=call_chains,
            suggestion=str(judgment.get("suggestion") or ""),
            raw_judgment=judgment,
        )

    def _llm_judgment(
        self,
        requirement: RequirementSnapshot,
        hits: List[CodeSearchHit],
        contexts: List[CodeGraphContext],
    ) -> Dict[str, Any]:
        """Ask the model for a judgment, falling back to "uncertain" on failure.

        Any exception raised while building the prompt, calling the model or
        parsing its answer is logged and converted into a fallback judgment
        flagged with ``"fallback": True``.

        Args:
            requirement: Snapshot of the requirement being judged.
            hits: Search hits used as evidence.
            contexts: Call contexts expanded from the hits.

        Returns:
            The parsed model judgment, or the fallback judgment.
        """
        try:
            evidence_context = format_evidence_context(hits, contexts)
            prompt = build_judgment_prompt(requirement, evidence_context)

            llm = self.llm
            if not llm:
                # Import the factory only when no model was injected, so a
                # missing or broken LLM stack cannot defeat a caller-supplied
                # model.
                from app.services.llm.llm_factory import LLMFactory

                llm = LLMFactory.create(temperature=0, streaming=False)

            response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
            # Responses without a ``content`` attribute are used as-is.
            text = getattr(response, "content", response)
            return self.parse_json_judgment(str(text))
        except Exception as exc:
            logger.exception("LLM consistency judgment failed: %s", exc)
            return {
                "verdict": "uncertain",
                "confidence": 0.2,
                "covered_points": [],
                "missing_points": ["模型判定失败,无法可靠确认覆盖情况。"],
                "conflict_points": [],
                "primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
                "reasoning": f"LLM judgment failed: {exc}",
                "suggestion": "请检查模型配置,或人工复核匹配函数证据。",
                "fallback": True,
            }

    @staticmethod
    def parse_json_judgment(raw_text: str) -> Dict[str, Any]:
        """Extract the JSON object carried by a model response.

        A leading Markdown code fence (optionally tagged ``json``) and a
        trailing fence are removed first.  If the remaining text is not itself
        a JSON object, the span from the first ``{`` to the last ``}`` is
        parsed instead.

        Args:
            raw_text: Raw text returned by the model.

        Returns:
            The decoded JSON object; always a ``dict``.

        Raises:
            json.JSONDecodeError: If no JSON object can be decoded, including
                when the text is valid JSON of another type (array, number,
                string) that contains no decodable object.
        """
        text = raw_text.strip()
        if text.startswith("```"):
            text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
            text = re.sub(r"```$", "", text).strip()

        first_error: Optional[json.JSONDecodeError] = None
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            parsed = None
            first_error = exc
        if isinstance(parsed, dict):
            return parsed

        # Either the text was not valid JSON or it decoded to a non-object;
        # look for an object embedded in it.
        match = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if match:
            candidate = json.loads(match.group(0))
            if isinstance(candidate, dict):
                return candidate
        if first_error is not None:
            raise first_error
        raise json.JSONDecodeError("LLM judgment is not a JSON object", text, 0)

    def _heuristic_judgment(
        self,
        hits: List[CodeSearchHit],
        contexts: List[CodeGraphContext],
    ) -> Dict[str, Any]:
        """Produce a conservative judgment from search similarity alone.

        The similarity of the first hit decides the verdict: ``"partial"``
        with confidence capped at 0.68 when it is at least 0.78, otherwise
        ``"uncertain"`` with confidence capped at 0.5.
        NOTE(review): this presumes ``hits`` is ordered best-first -- confirm
        against the adapter.

        Args:
            hits: Search hits; may be empty.
            contexts: Expanded call contexts; only their count is recorded.

        Returns:
            The heuristic judgment dictionary.
        """
        best = hits[0].similarity if hits else 0.0
        if best >= 0.78:
            verdict = "partial"
            confidence = min(0.68, best)
        else:
            verdict = "uncertain"
            confidence = min(0.5, best)
        return {
            "verdict": verdict,
            "confidence": confidence,
            "covered_points": [],
            "missing_points": ["未启用 LLM 判定,无法细分验收准则覆盖点。"],
            "conflict_points": [],
            "primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
            "reasoning": "仅基于向量召回和调用图生成保守判定。",
            "suggestion": "启用模型判定或人工复核主要匹配函数。",
            "call_context_count": len(contexts),
        }

    def _missing_judgment(self, reason: str) -> Dict[str, Any]:
        """Build the fixed "missing" judgment used when no evidence was found.

        Args:
            reason: Text used both as the single missing point and as the
                reasoning.

        Returns:
            A judgment with verdict ``"missing"`` and confidence 0.75.
        """
        return {
            "verdict": "missing",
            "confidence": 0.75,
            "covered_points": [],
            "missing_points": [reason],
            "conflict_points": [],
            "primary_evidence": [],
            "reasoning": reason,
            "suggestion": "补充代码实现或降低阈值后重新召回,并人工确认是否存在命名差异。",
        }

    def _normalize_judgment(self, judgment: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *judgment* with a valid verdict and confidence.

        The verdict is lower-cased and replaced by ``"uncertain"`` when it is
        not in ``VERDICTS``.  The confidence is converted to ``float`` and
        clamped to ``[0.0, 1.0]``; values that cannot be converted, and NaN,
        become ``0.0``.  Missing list and text fields get empty defaults.

        Args:
            judgment: Judgment dictionary from any of the judgment sources.

        Returns:
            A new, normalized dictionary; the input is not modified.
        """
        verdict = str(judgment.get("verdict") or "uncertain").strip().lower()
        if verdict not in VERDICTS:
            verdict = "uncertain"

        try:
            confidence = float(judgment.get("confidence", 0.0))
        except (TypeError, ValueError):
            confidence = 0.0
        if confidence != confidence:
            # NaN is the only value unequal to itself.  Without this guard
            # min(1.0, nan) evaluates to 1.0 and would report full confidence.
            confidence = 0.0
        confidence = max(0.0, min(1.0, confidence))

        normalized = dict(judgment)
        normalized["verdict"] = verdict
        normalized["confidence"] = confidence
        normalized.setdefault("covered_points", [])
        normalized.setdefault("missing_points", [])
        normalized.setdefault("conflict_points", [])
        normalized.setdefault("primary_evidence", [])
        normalized.setdefault("reasoning", "")
        normalized.setdefault("suggestion", "")
        return normalized

    def _matched_function_payload(self, hit: CodeSearchHit) -> Dict[str, Any]:
        """Flatten a search hit into a plain dictionary for the result item.

        The similarity is rounded to four decimals, the summary prefix (120
        characters) is reused as ``role``, long text fields are clipped, and
        the ``calls`` / ``called_by`` collections are limited to 20 entries.

        Args:
            hit: Search hit whose ``evidence`` supplies the function details.

        Returns:
            The payload dictionary.
        """
        item = hit.evidence
        return {
            "node_id": item.node_id,
            "name": item.name,
            "file": item.file,
            "start_line": item.start_line,
            "end_line": item.end_line,
            "similarity": round(hit.similarity, 4),
            "role": item.summary[:120] if item.summary else "",
            "evidence_summary": item.summary,
            "logic_flow": _clip(item.logic_flow, 1200),
            "code_snippet": _clip(item.code_snippet, 2000),
            "calls": item.calls[:20],
            "called_by": item.called_by[:20],
            "signature": item.signature,
        }

    def _collect_call_chains(self, contexts: List[CodeGraphContext]) -> List[str]:
        """Merge the call chains of all contexts, de-duplicated and capped.

        Args:
            contexts: Contexts whose ``call_chains`` are concatenated in order.

        Returns:
            At most 30 distinct chains, in first-seen order.
        """
        chains: List[str] = []
        for context in contexts:
            chains.extend(context.call_chains)
        # dict.fromkeys keeps the first occurrence of each chain, in order.
        return list(dict.fromkeys(chains))[:30]
|