Files
rag_agent/rag-web-ui/backend/app/services/consistency_job_service.py

712 lines
25 KiB
Python

from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
import zipfile
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import (
CodeKnowledgeBase,
ConsistencyJob,
ConsistencyResult,
SRSExtraction,
SRSRequirement,
ToolJob,
)
from app.schemas.consistency import CodeKnowledgeBaseCreate, ConsistencyJobCreate
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter, read_code_kb_summary
from app.services.code_kb.formatter import format_evidence_context
from app.services.consistency.comparator import ConsistencyComparator
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.srs_job_service import _build_internal_title, _parse_generated_at
from app.tools.srs_reqs_qwen import get_srs_tool
# Relative directory for uploaded code knowledge bases. It is not referenced
# elsewhere in the visible part of this module; presumably the upload
# endpoints use it (TODO confirm against callers).
CODE_UPLOAD_ROOT = Path("uploads") / "code_kbs"
# Relative directory for the automatic consistency pipeline.
# run_auto_consistency_job stores build output under
# AUTO_UPLOAD_ROOT / <tool_job_id> / "code_kb".
AUTO_UPLOAD_ROOT = Path("uploads") / "consistency_auto"
def _workspace_root() -> Path:
    """Locate the workspace directory that holds both sibling projects.

    Walks the ancestors of this file and returns the first one that contains
    both a ``rag-web-ui`` and a ``RAG-TEST-TOOLS`` entry.

    Returns:
        The matching ancestor. When none matches, the fifth ancestor of this
        file, clamped to the topmost available ancestor so that a shallow
        install location cannot raise ``IndexError``.
    """
    ancestors = Path(__file__).resolve().parents
    for parent in ancestors:
        if (parent / "rag-web-ui").exists() and (parent / "RAG-TEST-TOOLS").exists():
            return parent
    # Historical fallback is parents[4]; clamp it for short paths.
    return ancestors[min(4, len(ancestors) - 1)]
def _rag_test_tools_root() -> Path:
    """Return the ``RAG-TEST-TOOLS`` directory, probing several locations.

    Probe order: under the workspace root, under the fourth ancestor of this
    file, and next to the current working directory. The first existing
    location is returned resolved; if none exists, the first (unresolved)
    candidate is returned so callers can report a meaningful path.
    """
    here = Path(__file__).resolve()
    search_order = [
        _workspace_root() / "RAG-TEST-TOOLS",
        here.parents[3] / "RAG-TEST-TOOLS",
        Path.cwd().parent / "RAG-TEST-TOOLS",
    ]
    existing = next((location for location in search_order if location.exists()), None)
    if existing is None:
        return search_order[0]
    return existing.resolve()
def safe_upload_name(file_name: str | None, fallback: str = "upload.bin") -> str:
    """Reduce a client-supplied file name to a single safe path component.

    Args:
        file_name: Name sent by the client; may be ``None``, empty, or contain
            directory parts using either ``/`` or ``\\`` separators.
        fallback: Name used when no usable component remains.

    Returns:
        The final path component of ``file_name``, or ``fallback`` when that
        component is empty, ``"."`` or ``".."``.
    """
    # Treat backslashes as separators too, so Windows-style client paths lose
    # their directory part even when the server runs on POSIX.
    candidate = (file_name or fallback).replace("\\", "/")
    safe_name = Path(candidate).name
    # Path("..").name is "..", which would point at the parent directory.
    if safe_name in ("", ".", ".."):
        return fallback
    return safe_name
def ensure_upload_dir(*parts: str) -> Path:
    """Create ``uploads/<parts...>`` under the working directory if needed.

    Returns:
        The (relative) directory path, which exists after the call.
    """
    upload_dir = Path("uploads", *parts)
    upload_dir.mkdir(parents=True, exist_ok=True)
    return upload_dir
def save_uploaded_bytes(target_dir: Path, file_name: str, content: bytes) -> Path:
    """Persist an upload and unpack it when it is a zip archive.

    Args:
        target_dir: Directory to store the upload in; created when missing.
        file_name: Client-supplied name, sanitised via ``safe_upload_name``.
        content: Raw bytes of the upload.

    Returns:
        The stored file path, or for ``.zip`` uploads the directory
        ``target_dir / <archive stem>`` the archive was extracted into.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    stored_file = target_dir / safe_upload_name(file_name)
    stored_file.write_bytes(content)
    if stored_file.suffix.lower() != ".zip":
        return stored_file
    unpacked_dir = target_dir / stored_file.stem
    extract_zip_safe(stored_file, unpacked_dir)
    return unpacked_dir
def extract_zip_safe(zip_path: Path, target_dir: Path) -> None:
    """Extract a zip archive, refusing entries that would escape ``target_dir``.

    Every entry is validated before anything is written, so an archive with a
    single unsafe entry is rejected as a whole.

    Args:
        zip_path: Archive to extract.
        target_dir: Destination directory; created when missing.

    Raises:
        ValueError: If an entry resolves to a location outside ``target_dir``.
            The underlying path error is attached as ``__cause__``.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    target_root = target_dir.resolve()
    with zipfile.ZipFile(zip_path) as archive:
        for member in archive.infolist():
            member_path = (target_dir / member.filename).resolve()
            try:
                # relative_to raises ValueError when member_path is outside the root.
                member_path.relative_to(target_root)
            except ValueError as exc:
                raise ValueError(f"Unsafe zip entry: {member.filename}") from exc
        archive.extractall(target_dir)
def _build_code_kb_artifacts(
    project_path: str,
    output_dir: str,
    base_name: str,
    use_semantic: bool,
    model_profile: Any = None,
) -> Dict[str, str]:
    """Build code knowledge base artifacts by running the external builder.

    Runs ``python -m rag_test_tools.build_code_kb`` inside the RAG-TEST-TOOLS
    directory and parses the JSON document it prints on stdout.

    Args:
        project_path: Source tree (or file) to index; must exist.
        output_dir: Directory for the generated artifacts; created if missing.
        base_name: Value passed to the builder's ``--base-name`` option.
        use_semantic: When false, ``--skip-semantic`` is passed to the builder.
        model_profile: Optional object whose ``api_key``, ``api_base``,
            ``chat_model`` and ``embedding_model`` attributes are exported to
            the builder through environment variables.

    Returns:
        The JSON object printed by the builder.

    Raises:
        FileNotFoundError: If ``project_path`` is empty or missing, or the
            RAG-TEST-TOOLS directory cannot be found.
        RuntimeError: If the builder exits non-zero or prints invalid output.
        subprocess.TimeoutExpired: If the build exceeds one hour.
    """
    # An empty path would resolve to the server's working directory and index
    # the wrong tree, so reject it (and any missing path) before doing work.
    if not project_path or not Path(project_path).exists():
        raise FileNotFoundError(f"Code project path does not exist: {project_path!r}")

    tools_root = _rag_test_tools_root()
    if not tools_root.exists():
        raise FileNotFoundError(f"RAG-TEST-TOOLS not found: {tools_root}")

    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    command = [
        sys.executable,
        "-m",
        "rag_test_tools.build_code_kb",
        "--project",
        str(Path(project_path).resolve()),
        "--output",
        str(output_path),
        "--base-name",
        base_name,
    ]
    if not use_semantic:
        command.append("--skip-semantic")

    # Credentials are handed to the child via its environment only.
    env = os.environ.copy()
    if model_profile is not None:
        api_key = getattr(model_profile, "api_key", "") or ""
        api_base = getattr(model_profile, "api_base", "") or ""
        if api_key:
            env["DASHSCOPE_API_KEY"] = api_key
            env["DASH_SCOPE_API_KEY"] = api_key
            env["QWEN_API_KEY"] = api_key
        if api_base:
            env["DASH_SCOPE_API_BASE"] = api_base
            env["QWEN_API_URL"] = api_base
        if getattr(model_profile, "chat_model", None):
            env["QWEN_CHAT_MODEL"] = model_profile.chat_model
        if getattr(model_profile, "embedding_model", None):
            env["QWEN_EMBEDDING_MODEL"] = model_profile.embedding_model

    completed = subprocess.run(
        command,
        cwd=str(tools_root),
        env=env,
        capture_output=True,
        text=True,
        timeout=3600,
        check=False,
    )
    if completed.returncode != 0:
        raise RuntimeError(
            "Code knowledge base build failed: "
            f"{completed.stderr or completed.stdout or completed.returncode}"
        )
    try:
        artifacts = json.loads(completed.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Code KB build returned invalid JSON: {completed.stdout}") from exc
    # Callers index the result by key, so anything but an object is unusable.
    if not isinstance(artifacts, dict):
        raise RuntimeError(
            f"Code KB build returned unexpected JSON (expected an object): {completed.stdout}"
        )
    return artifacts
def _ensure_paths_exist(paths: Iterable[str]) -> None:
    """Verify that every given path is non-empty and exists on disk.

    Raises:
        FileNotFoundError: Listing all empty or missing paths.
    """
    missing = []
    for candidate in paths:
        if candidate and Path(candidate).exists():
            continue
        missing.append(candidate)
    if missing:
        raise FileNotFoundError(f"Code knowledge base file path does not exist: {missing}")
def create_code_kb(db: Session, user_id: int, payload: CodeKnowledgeBaseCreate) -> CodeKnowledgeBase:
    """Register an already-built code knowledge base for a user.

    Checks that the three artifact files exist, loads them through the
    adapter, then stores an ``active`` record whose summary merges the
    file-level summary with the adapter's summary (adapter values win).

    Raises:
        FileNotFoundError: If any artifact path is empty or missing.
    """
    vector_path = payload.vector_path
    metadata_path = payload.metadata_path
    graph_path = payload.graph_path
    _ensure_paths_exist([vector_path, metadata_path, graph_path])

    adapter = CodeKnowledgeBaseAdapter()
    adapter.load(vector_path, metadata_path, graph_path)
    file_summary = read_code_kb_summary(metadata_path, graph_path)
    summary = {**file_summary, **adapter.summary()}

    record = CodeKnowledgeBase(
        user_id=user_id,
        name=payload.name,
        project_path=payload.project_path,
        vector_path=vector_path,
        metadata_path=metadata_path,
        graph_path=graph_path,
        status="active",
        metadata_summary=summary,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
def create_uploaded_code_kb(
    db: Session,
    user_id: int,
    name: str,
    project_path: str,
    output_dir: str,
) -> CodeKnowledgeBase:
    """Create a ``pending`` code knowledge base record for uploaded sources.

    The artifact paths are only predicted here from a timestamp-based base
    name; ``run_code_kb_build`` produces the files later and reads
    ``base_name`` and ``output_dir`` back from ``metadata_summary``.
    """
    timestamp = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
    base_name = f"code_kb_{timestamp}"
    artifacts_dir = Path(output_dir).resolve()

    vector_file = artifacts_dir / f"{base_name}_rag.faiss"
    metadata_file = artifacts_dir / f"{base_name}_rag_metadata.json"
    graph_file = artifacts_dir / f"{base_name}_code_knowledge_graph.json"

    record = CodeKnowledgeBase(
        user_id=user_id,
        name=name,
        project_path=project_path,
        vector_path=str(vector_file),
        metadata_path=str(metadata_file),
        graph_path=str(graph_file),
        status="pending",
        metadata_summary={"base_name": base_name, "output_dir": str(artifacts_dir)},
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
def run_code_kb_build(code_kb_id: int, use_semantic: bool = True) -> None:
    """Build the artifacts for a code knowledge base in its own DB session.

    Moves the record through ``processing`` to ``active``. On any failure the
    record is marked ``failed`` and the (truncated) error text is stored in
    ``metadata_summary["error_message"]``; the exception is not re-raised.

    Args:
        code_kb_id: Primary key of the record to build; unknown ids are ignored.
        use_semantic: When true, an active model configuration is required and
            is used for both the build and the verification load.
    """
    db = SessionLocal()
    code_kb = None
    try:
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        if not code_kb:
            return

        model_profile = None
        if use_semantic:
            model_profile = ModelConfigService.require_active_config(db, code_kb.user_id)
            ModelConfigService.touch_last_used(db, model_profile)

        code_kb.status = "processing"
        db.add(code_kb)
        db.commit()

        summary = code_kb.metadata_summary or {}
        base_name = summary.get("base_name") or f"code_kb_{code_kb.id}"
        output_dir = summary.get("output_dir") or str(Path(code_kb.vector_path).parent)
        artifact_paths = _build_code_kb_artifacts(
            project_path=code_kb.project_path or "",
            output_dir=output_dir,
            base_name=base_name,
            use_semantic=use_semantic,
            model_profile=model_profile,
        )
        code_kb.graph_path = artifact_paths["graph_path"]
        code_kb.vector_path = artifact_paths["vector_path"]
        code_kb.metadata_path = artifact_paths["metadata_path"]

        embedding_function = (
            EmbeddingsFactory.create(model_profile=model_profile)
            if model_profile is not None
            else None
        )
        # Load the fresh artifacts once to prove they are readable before
        # the record is marked active.
        adapter = CodeKnowledgeBaseAdapter(embedding_function=embedding_function)
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

        code_kb.status = "active"
        code_kb.metadata_summary = {
            **read_code_kb_summary(code_kb.metadata_path, code_kb.graph_path),
            **adapter.summary(),
            "source": "upload",
        }
        db.add(code_kb)
        db.commit()
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; without this the status could stay "processing" forever.
        db.rollback()
        if code_kb is not None:
            code_kb.status = "failed"
            code_kb.metadata_summary = {
                **(code_kb.metadata_summary or {}),
                "error_message": str(exc)[:2000],
            }
            db.add(code_kb)
            db.commit()
    finally:
        db.close()
def list_code_kbs(db: Session, user_id: int) -> List[CodeKnowledgeBase]:
    """Return the user's code knowledge bases, newest first."""
    owned = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.user_id == user_id)
    newest_first = owned.order_by(CodeKnowledgeBase.created_at.desc())
    return newest_first.all()
def get_owned_code_kb(db: Session, user_id: int, code_kb_id: int) -> Optional[CodeKnowledgeBase]:
    """Fetch a code knowledge base by id, only if it belongs to ``user_id``.

    Returns:
        The record, or ``None`` when it does not exist or is owned by someone else.
    """
    criteria = (
        CodeKnowledgeBase.id == code_kb_id,
        CodeKnowledgeBase.user_id == user_id,
    )
    return db.query(CodeKnowledgeBase).filter(*criteria).first()
def get_owned_srs_extraction(db: Session, user_id: int, extraction_id: int) -> Optional[SRSExtraction]:
    """Fetch an SRS extraction by id, only if its tool job belongs to ``user_id``.

    Ownership is established through the joined ``ToolJob`` row.
    """
    with_job = db.query(SRSExtraction).join(ToolJob, SRSExtraction.job_id == ToolJob.id)
    owned_match = with_job.filter(
        SRSExtraction.id == extraction_id,
        ToolJob.user_id == user_id,
    )
    return owned_match.first()
def create_consistency_job(
    db: Session,
    user_id: int,
    payload: ConsistencyJobCreate,
) -> ConsistencyJob:
    """Validate the request and persist a ``pending`` consistency job.

    Raises:
        ValueError: If the extraction or code knowledge base is not owned by
            the user, the knowledge base is not ``active``, or no requirement
            matches the selected scope.
    """
    extraction = get_owned_srs_extraction(db, user_id, payload.srs_extraction_id)
    if not extraction:
        raise ValueError("SRS extraction does not exist.")

    code_kb = get_owned_code_kb(db, user_id, payload.code_kb_id)
    if not code_kb:
        raise ValueError("Code knowledge base does not exist.")
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")

    # Count the requirements in scope; an explicit uid list narrows the set.
    selected_uids = payload.requirement_uids
    scope = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id)
    if selected_uids:
        scope = scope.filter(SRSRequirement.requirement_uid.in_(selected_uids))
    matched = scope.count()
    if matched == 0:
        raise ValueError("No SRS requirements matched the selected scope.")

    # The run options travel with the job inside output_summary.
    option_names = ("requirement_uids", "top_k", "max_call_hops", "min_similarity", "use_llm")
    run_options = {option: getattr(payload, option) for option in option_names}

    job = ConsistencyJob(
        user_id=user_id,
        srs_extraction_id=extraction.id,
        code_kb_id=code_kb.id,
        status="pending",
        total_requirements=matched,
        completed_requirements=0,
        output_summary=run_options,
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    return job
def list_consistency_jobs(db: Session, user_id: int) -> List[ConsistencyJob]:
    """Return the user's consistency jobs, newest first."""
    owned = db.query(ConsistencyJob).filter(ConsistencyJob.user_id == user_id)
    newest_first = owned.order_by(ConsistencyJob.created_at.desc())
    return newest_first.all()
def get_owned_consistency_job(db: Session, user_id: int, job_id: int) -> Optional[ConsistencyJob]:
    """Fetch a consistency job by id, only if it belongs to ``user_id``.

    Returns:
        The job, or ``None`` when it does not exist or is owned by someone else.
    """
    criteria = (
        ConsistencyJob.id == job_id,
        ConsistencyJob.user_id == user_id,
    )
    return db.query(ConsistencyJob).filter(*criteria).first()
def ask_code_kb(
    code_kb: CodeKnowledgeBase,
    question: str,
    top_k: int = 6,
    min_similarity: float = 0.0,
    use_llm: bool = True,
    model_profile: Any = None,
    max_call_hops: int = 2,
) -> Dict[str, Any]:
    """Answer a question from a code knowledge base.

    Retrieves matching functions and, when ``use_llm`` is set, asks the model
    to answer strictly from that evidence.

    Args:
        code_kb: Knowledge base to query; must have status ``active``.
        question: Natural-language question.
        top_k: Maximum number of function hits to retrieve.
        min_similarity: Similarity threshold passed to the search.
        use_llm: When false, only the retrieved evidence is returned.
        model_profile: Active model configuration; required.
        max_call_hops: Call-graph depth used when expanding each hit for the
            prompt (defaults to the previously hard-coded value of 2).

    Returns:
        A dict with ``answer``, ``evidence`` (list of hit dicts) and
        ``raw_response`` (the model text, or ``None``).

    Raises:
        ValueError: If the knowledge base is not active or no model profile
            is supplied.
    """
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")
    if model_profile is None:
        raise ValueError("请先在 API 密钥页面新增并启用模型配置。")

    adapter = CodeKnowledgeBaseAdapter(
        embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
    )
    adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)
    hits = adapter.search_functions(question, top_k=top_k, min_similarity=min_similarity)
    if not hits:
        return {
            "answer": "未检索到相关函数证据,无法基于代码知识库回答。",
            "evidence": [],
            "raw_response": None,
        }

    evidence = [hit.to_dict() for hit in hits]
    if not use_llm:
        return {
            "answer": "已检索到相关函数证据,请查看 evidence 字段中的函数摘要、文件位置和调用链。",
            "evidence": evidence,
            "raw_response": None,
        }

    # Call-chain expansion and formatting only feed the prompt, so they are
    # done after the early returns above.
    contexts = [
        adapter.expand_call_context(hit.evidence.node_id, max_hops=max_call_hops)
        for hit in hits
    ]
    evidence_context = format_evidence_context(hits, contexts)
    prompt = (
        "你是代码知识库问答助手。只能基于给定代码证据回答问题;"
        "如果证据不足,请明确说明不足。回答需要包含关键函数名和文件位置。\n\n"
        f"问题:{question}\n\n代码证据:\n{evidence_context}"
    )
    try:
        llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
        answer = str(getattr(response, "content", response))
        return {"answer": answer, "evidence": evidence, "raw_response": answer}
    except Exception as exc:
        # Deliberate best effort: a model failure still returns the evidence.
        return {
            "answer": f"模型问答失败,已返回检索证据供人工查看。错误:{exc}",
            "evidence": evidence,
            "raw_response": None,
        }
def result_model_to_export_dict(result: ConsistencyResult) -> Dict[str, Any]:
    """Flatten a stored consistency result into a plain dict for export.

    Requirement title, type and text come from ``raw_judgment``; empty list
    fields become ``[]`` and an empty suggestion becomes ``""``.
    """
    judgment = result.raw_judgment or {}
    exported: Dict[str, Any] = {"requirement_uid": result.requirement_uid}
    exported["requirement_title"] = judgment.get("requirement_title", "")
    exported["requirement_type"] = judgment.get("requirement_type")
    exported["requirement_text"] = judgment.get("requirement_text", "")
    for scalar_field in ("verdict", "coverage_score", "confidence"):
        exported[scalar_field] = getattr(result, scalar_field)
    list_fields = (
        "matched_functions",
        "covered_points",
        "missing_points",
        "conflict_points",
        "call_chain_evidence",
    )
    for list_field in list_fields:
        exported[list_field] = getattr(result, list_field) or []
    exported["suggestion"] = result.suggestion or ""
    exported["raw_judgment"] = judgment
    return exported
def _store_result(db: Session, job: ConsistencyJob, result: Any) -> None:
    """Add one comparison result to the session as a ``ConsistencyResult`` row.

    The requirement title, type and text from ``result.to_dict()`` are copied
    into ``raw_judgment``, which is where the export helper reads them back.
    The row is only added; committing is left to the caller.
    """
    as_dict = result.to_dict()
    raw_judgment = dict(as_dict.get("raw_judgment") or {})
    for key in ("requirement_title", "requirement_type", "requirement_text"):
        raw_judgment[key] = as_dict.get(key)

    row = ConsistencyResult(
        job_id=job.id,
        requirement_uid=result.requirement_uid,
        verdict=result.verdict,
        coverage_score=result.coverage_score,
        confidence=result.confidence,
        matched_functions=result.matched_functions,
        covered_points=result.covered_points,
        missing_points=result.missing_points,
        conflict_points=result.conflict_points,
        call_chain_evidence=result.call_chain_evidence,
        suggestion=result.suggestion,
        raw_judgment=raw_judgment,
    )
    db.add(row)
def _create_srs_extraction_for_job(db: Session, job: ToolJob) -> SRSExtraction:
    """Run the SRS tool on the job's input file and stage the extracted rows.

    Requires an active model configuration for the job's user. The extraction
    is flushed to obtain its id and one ``SRSRequirement`` is added per
    extracted item; nothing is committed here.
    """
    model_profile = ModelConfigService.require_active_config(db, job.user_id)
    ModelConfigService.touch_last_used(db, model_profile)
    payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
    items = payload.get("requirements", [])

    extraction = SRSExtraction(
        job_id=job.id,
        document_name=payload["document_name"],
        document_title=payload.get("document_title") or payload["document_name"],
        generated_at=_parse_generated_at(payload.get("generated_at")),
        total_requirements=len(items),
        statistics=payload.get("statistics", {}),
        raw_output=payload.get("raw_output", {}),
    )
    db.add(extraction)
    db.flush()  # assigns extraction.id for the requirement rows below

    # Fields copied from the tool output unchanged (missing keys become None).
    passthrough_fields = (
        "section_uid",
        "section_number",
        "section_title",
        "requirement_type",
        "interface_name",
        "interface_type",
        "data_source",
        "data_destination",
    )
    for position, item in enumerate(items):
        uid = item.get("id")
        copied = {field: item.get(field) for field in passthrough_fields}
        db.add(
            SRSRequirement(
                extraction_id=extraction.id,
                requirement_uid=uid or f"REQ-{position + 1:03d}",
                title=_build_internal_title(item.get("description"), uid or "", position),
                description=item.get("description") or "",
                priority=item.get("priority") or "",
                acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
                source_field=item.get("source_field") or "文档解析",
                sort_order=int(item.get("sort_order") or position),
                **copied,
            )
        )
    return extraction
def create_auto_consistency_tool_job(
    db: Session,
    user_id: int,
    requirement_file_path: str,
    requirement_file_name: str,
    code_source_dir: str,
    code_kb_name: str,
    top_k: int,
    max_call_hops: int,
    min_similarity: float,
    use_llm: bool,
    use_semantic: bool,
) -> ToolJob:
    """Persist a ``pending`` tool job for the automatic consistency pipeline.

    All run options are stored in ``output_summary`` so that
    ``run_auto_consistency_job`` can pick them up from the job id alone.
    """
    run_options = {
        "current_step": "pending",
        "code_source_dir": code_source_dir,
        "code_kb_name": code_kb_name,
        "top_k": top_k,
        "max_call_hops": max_call_hops,
        "min_similarity": min_similarity,
        "use_llm": use_llm,
        "use_semantic": use_semantic,
    }
    tool_job = ToolJob(
        user_id=user_id,
        tool_name="consistency.auto_compare",
        status="pending",
        input_file_name=requirement_file_name,
        input_file_path=requirement_file_path,
        output_summary=run_options,
    )
    db.add(tool_job)
    db.commit()
    db.refresh(tool_job)
    return tool_job
def get_owned_auto_job(db: Session, user_id: int, job_id: int) -> Optional[ToolJob]:
    """Fetch an automatic-consistency tool job by id, only if owned by ``user_id``.

    Jobs of any other tool are never returned, even when the id matches.
    """
    criteria = (
        ToolJob.id == job_id,
        ToolJob.user_id == user_id,
        ToolJob.tool_name == "consistency.auto_compare",
    )
    return db.query(ToolJob).filter(*criteria).first()
def run_auto_consistency_job(tool_job_id: int) -> None:
    """Run the automatic SRS-versus-code consistency pipeline for a tool job.

    Steps, each recorded in ``output_summary["current_step"]``:
    ``extracting_requirements`` -> ``building_code_kb`` -> ``comparing`` ->
    ``completed``. The code KB build and the comparison run in their own
    sessions, so this function closes its session around each of them and
    carries plain ids (never ORM instances) across those boundaries.

    On any failure the tool job is marked ``failed`` with a truncated error
    message; the exception is not re-raised. Unknown job ids are ignored.
    """
    db = SessionLocal()
    try:
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return
        # Copy plain values now: ORM attributes are expired by commit() and
        # cannot be reloaded once the session has been closed.
        user_id = tool_job.user_id
        options = tool_job.output_summary or {}
        tool_job.status = "processing"
        tool_job.started_at = datetime.utcnow()
        tool_job.output_summary = {**options, "current_step": "extracting_requirements"}
        db.add(tool_job)
        db.commit()

        # Step 1: extract requirements from the uploaded document.
        extraction = _create_srs_extraction_for_job(db, tool_job)
        db.commit()
        extraction_id = extraction.id
        options = tool_job.output_summary or options

        # Step 2: register and build the code knowledge base.
        code_output_dir = str((AUTO_UPLOAD_ROOT / str(tool_job_id) / "code_kb").resolve())
        code_kb = create_uploaded_code_kb(
            db,
            user_id,
            options.get("code_kb_name") or f"auto-code-kb-{tool_job_id}",
            options["code_source_dir"],
            code_output_dir,
        )
        code_kb_id = code_kb.id
        tool_job.output_summary = {
            **options,
            "current_step": "building_code_kb",
            "srs_extraction_id": extraction_id,
            "code_kb_id": code_kb_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_code_kb_build(code_kb_id, use_semantic=bool(options.get("use_semantic", True)))

        db = SessionLocal()
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return
        if not code_kb:
            # Raise so the job is marked failed instead of staying "processing".
            raise RuntimeError("Code knowledge base does not exist.")
        if code_kb.status != "active":
            raise RuntimeError((code_kb.metadata_summary or {}).get("error_message") or "Code KB build failed.")

        # Step 3: compare requirements against the code knowledge base.
        consistency_payload = ConsistencyJobCreate(
            srs_extraction_id=extraction_id,
            code_kb_id=code_kb_id,
            top_k=int(options.get("top_k", 8)),
            max_call_hops=int(options.get("max_call_hops", 2)),
            min_similarity=float(options.get("min_similarity", 0.55)),
            use_llm=bool(options.get("use_llm", True)),
        )
        consistency_job = create_consistency_job(db, user_id, consistency_payload)
        consistency_job_id = consistency_job.id
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "comparing",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_consistency_job(consistency_job_id)

        db = SessionLocal()
        consistency_job = (
            db.query(ConsistencyJob).filter(ConsistencyJob.id == consistency_job_id).first()
        )
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return
        if not consistency_job:
            raise RuntimeError("Consistency job does not exist.")
        if consistency_job.status == "failed":
            raise RuntimeError(consistency_job.error_message or "Consistency comparison failed.")

        tool_job.status = "completed"
        tool_job.completed_at = datetime.utcnow()
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "completed",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
    except Exception as exc:
        # Discard any half-finished transaction so the failure can be recorded.
        db.rollback()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if tool_job:
            tool_job.status = "failed"
            tool_job.error_message = str(exc)[:2000]
            tool_job.completed_at = datetime.utcnow()
            tool_job.output_summary = {
                **(tool_job.output_summary or {}),
                "current_step": "failed",
            }
            db.add(tool_job)
            db.commit()
    finally:
        db.close()
def run_consistency_job(job_id: int) -> None:
    """Compare every in-scope SRS requirement against the code knowledge base.

    Runs in its own DB session. Progress (``completed_requirements`` and the
    running verdict counts) is committed after each requirement. On failure
    the job is marked ``failed`` with a truncated error message; the
    exception is not re-raised. Unknown job ids are ignored.
    """
    db = SessionLocal()
    job = None
    try:
        job = db.query(ConsistencyJob).filter(ConsistencyJob.id == job_id).first()
        if not job:
            return
        job.status = "processing"
        job.started_at = datetime.utcnow()
        db.add(job)
        db.commit()

        options = job.output_summary or {}
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == job.code_kb_id).first()
        if not code_kb:
            raise RuntimeError("Code knowledge base does not exist.")

        model_profile = ModelConfigService.require_active_config(db, job.user_id)
        ModelConfigService.touch_last_used(db, model_profile)
        adapter = CodeKnowledgeBaseAdapter(
            embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
        )
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

        use_llm = bool(options.get("use_llm", True))
        llm = None
        if use_llm:
            llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        comparator = ConsistencyComparator(adapter, llm=llm, use_llm=use_llm)

        query = (
            db.query(SRSRequirement)
            .filter(SRSRequirement.extraction_id == job.srs_extraction_id)
            .order_by(SRSRequirement.sort_order)
        )
        requirement_uids = options.get("requirement_uids")
        if requirement_uids:
            query = query.filter(SRSRequirement.requirement_uid.in_(requirement_uids))
        requirements = query.all()
        job.total_requirements = len(requirements)
        db.add(job)
        db.commit()

        # The search options are the same for every requirement; parse once.
        top_k = int(options.get("top_k", 8))
        max_call_hops = int(options.get("max_call_hops", 2))
        min_similarity = float(options.get("min_similarity", 0.55))

        verdict_counter: Counter[str] = Counter()
        for requirement in requirements:
            result = comparator.compare_requirement(
                requirement,
                top_k=top_k,
                max_call_hops=max_call_hops,
                min_similarity=min_similarity,
            )
            _store_result(db, job, result)
            verdict_counter[result.verdict] += 1
            job.completed_requirements += 1
            job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
            db.add(job)
            db.commit()

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
        db.add(job)
        db.commit()
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until rolled back.
        db.rollback()
        if job is not None:
            job.status = "failed"
            # Same 2000-character cap as run_auto_consistency_job uses.
            job.error_message = str(exc)[:2000]
            job.completed_at = datetime.utcnow()
            db.add(job)
            db.commit()
    finally:
        db.close()