from __future__ import annotations import json import os import shutil import subprocess import sys import zipfile from collections import Counter from datetime import datetime from pathlib import Path from typing import Any, Dict, Iterable, List, Optional from sqlalchemy.orm import Session from app.db.session import SessionLocal from app.models.tooling import ( CodeKnowledgeBase, ConsistencyJob, ConsistencyResult, SRSExtraction, SRSRequirement, ToolJob, ) from app.schemas.consistency import CodeKnowledgeBaseCreate, ConsistencyJobCreate from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter, read_code_kb_summary from app.services.code_kb.formatter import format_evidence_context from app.services.consistency.comparator import ConsistencyComparator from app.services.embedding.embedding_factory import EmbeddingsFactory from app.services.llm.llm_factory import LLMFactory from app.services.model_config import ModelConfigService from app.services.srs_job_service import _build_internal_title, _parse_generated_at from app.tools.srs_reqs_qwen import get_srs_tool CODE_UPLOAD_ROOT = Path("uploads") / "code_kbs" AUTO_UPLOAD_ROOT = Path("uploads") / "consistency_auto" def _workspace_root() -> Path: current = Path(__file__).resolve() for parent in current.parents: if (parent / "rag-web-ui").exists() and (parent / "RAG-TEST-TOOLS").exists(): return parent return current.parents[4] def _rag_test_tools_root() -> Path: candidates = [ _workspace_root() / "RAG-TEST-TOOLS", Path(__file__).resolve().parents[3] / "RAG-TEST-TOOLS", Path.cwd().parent / "RAG-TEST-TOOLS", ] for candidate in candidates: if candidate.exists(): return candidate.resolve() return candidates[0] def safe_upload_name(file_name: str | None, fallback: str = "upload.bin") -> str: safe_name = Path(file_name or fallback).name return safe_name or fallback def ensure_upload_dir(*parts: str) -> Path: path = Path("uploads").joinpath(*parts) path.mkdir(parents=True, exist_ok=True) return path def save_uploaded_bytes(target_dir: Path, file_name: str, content: bytes) -> Path: target_dir.mkdir(parents=True, exist_ok=True) path = target_dir / safe_upload_name(file_name) path.write_bytes(content) if path.suffix.lower() == ".zip": extract_dir = target_dir / path.stem extract_zip_safe(path, extract_dir) return extract_dir return path def extract_zip_safe(zip_path: Path, target_dir: Path) -> None: target_dir.mkdir(parents=True, exist_ok=True) target_root = target_dir.resolve() with zipfile.ZipFile(zip_path) as archive: for member in archive.infolist(): member_path = (target_dir / member.filename).resolve() try: member_path.relative_to(target_root) except ValueError as exc: raise ValueError(f"Unsafe zip entry: {member.filename}") archive.extractall(target_dir) def _build_code_kb_artifacts( project_path: str, output_dir: str, base_name: str, use_semantic: bool, model_profile: Any = None, ) -> Dict[str, str]: tools_root = _rag_test_tools_root() if not tools_root.exists(): raise FileNotFoundError(f"RAG-TEST-TOOLS not found: {tools_root}") output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) command = [ sys.executable, "-m", "rag_test_tools.build_code_kb", "--project", str(Path(project_path).resolve()), "--output", str(output_path), "--base-name", base_name, ] if not use_semantic: command.append("--skip-semantic") env = os.environ.copy() if model_profile is not None: api_key = getattr(model_profile, "api_key", "") or "" api_base = getattr(model_profile, "api_base", "") or "" if api_key: env["DASHSCOPE_API_KEY"] = api_key env["DASH_SCOPE_API_KEY"] = api_key env["QWEN_API_KEY"] = api_key if api_base: env["DASH_SCOPE_API_BASE"] = api_base env["QWEN_API_URL"] = api_base if getattr(model_profile, "chat_model", None): env["QWEN_CHAT_MODEL"] = model_profile.chat_model if getattr(model_profile, "embedding_model", None): env["QWEN_EMBEDDING_MODEL"] = model_profile.embedding_model completed = subprocess.run( command, cwd=str(tools_root), env=env, capture_output=True, text=True, timeout=3600, check=False, ) if completed.returncode != 0: raise RuntimeError( "Code knowledge base build failed: " f"{completed.stderr or completed.stdout or completed.returncode}" ) try: return json.loads(completed.stdout) except json.JSONDecodeError as exc: raise RuntimeError(f"Code KB build returned invalid JSON: {completed.stdout}") from exc def _ensure_paths_exist(paths: Iterable[str]) -> None: missing = [path for path in paths if not path or not Path(path).exists()] if missing: raise FileNotFoundError(f"Code knowledge base file path does not exist: {missing}") def create_code_kb(db: Session, user_id: int, payload: CodeKnowledgeBaseCreate) -> CodeKnowledgeBase: _ensure_paths_exist([payload.vector_path, payload.metadata_path, payload.graph_path]) adapter = CodeKnowledgeBaseAdapter() adapter.load(payload.vector_path, payload.metadata_path, payload.graph_path) summary = { **read_code_kb_summary(payload.metadata_path, payload.graph_path), **adapter.summary(), } code_kb = CodeKnowledgeBase( user_id=user_id, name=payload.name, project_path=payload.project_path, vector_path=payload.vector_path, metadata_path=payload.metadata_path, graph_path=payload.graph_path, status="active", metadata_summary=summary, ) db.add(code_kb) db.commit() db.refresh(code_kb) return code_kb def create_uploaded_code_kb( db: Session, user_id: int, name: str, project_path: str, output_dir: str, ) -> CodeKnowledgeBase: base_name = f"code_kb_{datetime.utcnow().strftime('%Y%m%d%H%M%S%f')}" output_path = Path(output_dir).resolve() code_kb = CodeKnowledgeBase( user_id=user_id, name=name, project_path=project_path, vector_path=str(output_path / f"{base_name}_rag.faiss"), metadata_path=str(output_path / f"{base_name}_rag_metadata.json"), graph_path=str(output_path / f"{base_name}_code_knowledge_graph.json"), status="pending", metadata_summary={"base_name": base_name, "output_dir": str(output_path)}, ) db.add(code_kb) db.commit() db.refresh(code_kb) return code_kb def run_code_kb_build(code_kb_id: int, use_semantic: bool = True) -> None: db = SessionLocal() try: code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first() if not code_kb: return model_profile = None if use_semantic: model_profile = ModelConfigService.require_active_config(db, code_kb.user_id) ModelConfigService.touch_last_used(db, model_profile) code_kb.status = "processing" db.add(code_kb) db.commit() summary = code_kb.metadata_summary or {} base_name = summary.get("base_name") or f"code_kb_{code_kb.id}" output_dir = summary.get("output_dir") or str(Path(code_kb.vector_path).parent) artifact_paths = _build_code_kb_artifacts( project_path=code_kb.project_path or "", output_dir=output_dir, base_name=base_name, use_semantic=use_semantic, model_profile=model_profile, ) code_kb.graph_path = artifact_paths["graph_path"] code_kb.vector_path = artifact_paths["vector_path"] code_kb.metadata_path = artifact_paths["metadata_path"] embedding_function = ( EmbeddingsFactory.create(model_profile=model_profile) if model_profile is not None else None ) adapter = CodeKnowledgeBaseAdapter(embedding_function=embedding_function) adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path) code_kb.status = "active" code_kb.metadata_summary = { **read_code_kb_summary(code_kb.metadata_path, code_kb.graph_path), **adapter.summary(), "source": "upload", } db.add(code_kb) db.commit() except Exception as exc: if "code_kb" in locals() and code_kb: code_kb.status = "failed" code_kb.metadata_summary = { **(code_kb.metadata_summary or {}), "error_message": str(exc)[:2000], } db.add(code_kb) db.commit() finally: db.close() def list_code_kbs(db: Session, user_id: int) -> List[CodeKnowledgeBase]: return ( db.query(CodeKnowledgeBase) .filter(CodeKnowledgeBase.user_id == user_id) .order_by(CodeKnowledgeBase.created_at.desc()) .all() ) def get_owned_code_kb(db: Session, user_id: int, code_kb_id: int) -> Optional[CodeKnowledgeBase]: return ( db.query(CodeKnowledgeBase) .filter(CodeKnowledgeBase.id == code_kb_id, CodeKnowledgeBase.user_id == user_id) .first() ) def get_owned_srs_extraction(db: Session, user_id: int, extraction_id: int) -> Optional[SRSExtraction]: return ( db.query(SRSExtraction) .join(ToolJob, SRSExtraction.job_id == ToolJob.id) .filter(SRSExtraction.id == extraction_id, ToolJob.user_id == user_id) .first() ) def create_consistency_job( db: Session, user_id: int, payload: ConsistencyJobCreate, ) -> ConsistencyJob: extraction = get_owned_srs_extraction(db, user_id, payload.srs_extraction_id) if not extraction: raise ValueError("SRS extraction does not exist.") code_kb = get_owned_code_kb(db, user_id, payload.code_kb_id) if not code_kb: raise ValueError("Code knowledge base does not exist.") if code_kb.status != "active": raise ValueError("Code knowledge base is not active.") requirement_query = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id) if payload.requirement_uids: requirement_query = requirement_query.filter(SRSRequirement.requirement_uid.in_(payload.requirement_uids)) total = requirement_query.count() if total == 0: raise ValueError("No SRS requirements matched the selected scope.") job = ConsistencyJob( user_id=user_id, srs_extraction_id=extraction.id, code_kb_id=code_kb.id, status="pending", total_requirements=total, completed_requirements=0, output_summary={ "requirement_uids": payload.requirement_uids, "top_k": payload.top_k, "max_call_hops": payload.max_call_hops, "min_similarity": payload.min_similarity, "use_llm": payload.use_llm, }, ) db.add(job) db.commit() db.refresh(job) return job def list_consistency_jobs(db: Session, user_id: int) -> List[ConsistencyJob]: return ( db.query(ConsistencyJob) .filter(ConsistencyJob.user_id == user_id) .order_by(ConsistencyJob.created_at.desc()) .all() ) def get_owned_consistency_job(db: Session, user_id: int, job_id: int) -> Optional[ConsistencyJob]: return ( db.query(ConsistencyJob) .filter(ConsistencyJob.id == job_id, ConsistencyJob.user_id == user_id) .first() ) def ask_code_kb( code_kb: CodeKnowledgeBase, question: str, top_k: int = 6, min_similarity: float = 0.0, use_llm: bool = True, model_profile: Any = None, ) -> Dict[str, Any]: if code_kb.status != "active": raise ValueError("Code knowledge base is not active.") if model_profile is None: raise ValueError("请先在 API 密钥页面新增并启用模型配置。") adapter = CodeKnowledgeBaseAdapter( embedding_function=EmbeddingsFactory.create(model_profile=model_profile) ) adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path) hits = adapter.search_functions(question, top_k=top_k, min_similarity=min_similarity) contexts = [adapter.expand_call_context(hit.evidence.node_id, max_hops=2) for hit in hits] evidence = [hit.to_dict() for hit in hits] if not hits: return { "answer": "未检索到相关函数证据,无法基于代码知识库回答。", "evidence": [], "raw_response": None, } evidence_context = format_evidence_context(hits, contexts) if not use_llm: return { "answer": "已检索到相关函数证据,请查看 evidence 字段中的函数摘要、文件位置和调用链。", "evidence": evidence, "raw_response": None, } prompt = ( "你是代码知识库问答助手。只能基于给定代码证据回答问题;" "如果证据不足,请明确说明不足。回答需要包含关键函数名和文件位置。\n\n" f"问题:{question}\n\n代码证据:\n{evidence_context}" ) try: llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile) response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt) answer = str(getattr(response, "content", response)) return {"answer": answer, "evidence": evidence, "raw_response": answer} except Exception as exc: return { "answer": f"模型问答失败,已返回检索证据供人工查看。错误:{exc}", "evidence": evidence, "raw_response": None, } def result_model_to_export_dict(result: ConsistencyResult) -> Dict[str, Any]: raw = result.raw_judgment or {} return { "requirement_uid": result.requirement_uid, "requirement_title": raw.get("requirement_title", ""), "requirement_type": raw.get("requirement_type"), "requirement_text": raw.get("requirement_text", ""), "verdict": result.verdict, "coverage_score": result.coverage_score, "confidence": result.confidence, "matched_functions": result.matched_functions or [], "covered_points": result.covered_points or [], "missing_points": result.missing_points or [], "conflict_points": result.conflict_points or [], "call_chain_evidence": result.call_chain_evidence or [], "suggestion": result.suggestion or "", "raw_judgment": raw, } def _store_result(db: Session, job: ConsistencyJob, result: Any) -> None: result_dict = result.to_dict() raw_judgment = dict(result_dict.get("raw_judgment") or {}) raw_judgment.update( { "requirement_title": result_dict.get("requirement_title"), "requirement_type": result_dict.get("requirement_type"), "requirement_text": result_dict.get("requirement_text"), } ) db.add( ConsistencyResult( job_id=job.id, requirement_uid=result.requirement_uid, verdict=result.verdict, coverage_score=result.coverage_score, confidence=result.confidence, matched_functions=result.matched_functions, covered_points=result.covered_points, missing_points=result.missing_points, conflict_points=result.conflict_points, call_chain_evidence=result.call_chain_evidence, suggestion=result.suggestion, raw_judgment=raw_judgment, ) ) def _create_srs_extraction_for_job(db: Session, job: ToolJob) -> SRSExtraction: model_profile = ModelConfigService.require_active_config(db, job.user_id) ModelConfigService.touch_last_used(db, model_profile) payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile) extraction = SRSExtraction( job_id=job.id, document_name=payload["document_name"], document_title=payload.get("document_title") or payload["document_name"], generated_at=_parse_generated_at(payload.get("generated_at")), total_requirements=len(payload.get("requirements", [])), statistics=payload.get("statistics", {}), raw_output=payload.get("raw_output", {}), ) db.add(extraction) db.flush() for index, item in enumerate(payload.get("requirements", [])): requirement = SRSRequirement( extraction_id=extraction.id, requirement_uid=item.get("id") or f"REQ-{index + 1:03d}", title=_build_internal_title(item.get("description"), item.get("id") or "", index), description=item.get("description") or "", priority=item.get("priority") or "中", acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"], source_field=item.get("source_field") or "文档解析", section_uid=item.get("section_uid"), section_number=item.get("section_number"), section_title=item.get("section_title"), requirement_type=item.get("requirement_type"), interface_name=item.get("interface_name"), interface_type=item.get("interface_type"), data_source=item.get("data_source"), data_destination=item.get("data_destination"), sort_order=int(item.get("sort_order") or index), ) db.add(requirement) return extraction def create_auto_consistency_tool_job( db: Session, user_id: int, requirement_file_path: str, requirement_file_name: str, code_source_dir: str, code_kb_name: str, top_k: int, max_call_hops: int, min_similarity: float, use_llm: bool, use_semantic: bool, ) -> ToolJob: job = ToolJob( user_id=user_id, tool_name="consistency.auto_compare", status="pending", input_file_name=requirement_file_name, input_file_path=requirement_file_path, output_summary={ "current_step": "pending", "code_source_dir": code_source_dir, "code_kb_name": code_kb_name, "top_k": top_k, "max_call_hops": max_call_hops, "min_similarity": min_similarity, "use_llm": use_llm, "use_semantic": use_semantic, }, ) db.add(job) db.commit() db.refresh(job) return job def get_owned_auto_job(db: Session, user_id: int, job_id: int) -> Optional[ToolJob]: return ( db.query(ToolJob) .filter( ToolJob.id == job_id, ToolJob.user_id == user_id, ToolJob.tool_name == "consistency.auto_compare", ) .first() ) def run_auto_consistency_job(tool_job_id: int) -> None: db = SessionLocal() try: tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first() if not tool_job: return options = tool_job.output_summary or {} tool_job.status = "processing" tool_job.started_at = datetime.utcnow() tool_job.output_summary = {**options, "current_step": "extracting_requirements"} db.add(tool_job) db.commit() extraction = _create_srs_extraction_for_job(db, tool_job) db.commit() options = tool_job.output_summary or options code_output_dir = str((AUTO_UPLOAD_ROOT / str(tool_job.id) / "code_kb").resolve()) code_kb = create_uploaded_code_kb( db, tool_job.user_id, options.get("code_kb_name") or f"auto-code-kb-{tool_job.id}", options["code_source_dir"], code_output_dir, ) tool_job.output_summary = { **options, "current_step": "building_code_kb", "srs_extraction_id": extraction.id, "code_kb_id": code_kb.id, } db.add(tool_job) db.commit() db.close() run_code_kb_build(code_kb.id, use_semantic=bool(options.get("use_semantic", True))) db = SessionLocal() code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb.id).first() tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first() if not tool_job or not code_kb: return if code_kb.status != "active": raise RuntimeError((code_kb.metadata_summary or {}).get("error_message") or "Code KB build failed.") consistency_payload = ConsistencyJobCreate( srs_extraction_id=extraction.id, code_kb_id=code_kb.id, top_k=int(options.get("top_k", 8)), max_call_hops=int(options.get("max_call_hops", 2)), min_similarity=float(options.get("min_similarity", 0.55)), use_llm=bool(options.get("use_llm", True)), ) consistency_job = create_consistency_job(db, tool_job.user_id, consistency_payload) tool_job.output_summary = { **(tool_job.output_summary or {}), "current_step": "comparing", "consistency_job_id": consistency_job.id, } db.add(tool_job) db.commit() db.close() run_consistency_job(consistency_job.id) db = SessionLocal() consistency_job = db.query(ConsistencyJob).filter(ConsistencyJob.id == consistency_job.id).first() tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first() if not tool_job or not consistency_job: return if consistency_job.status == "failed": raise RuntimeError(consistency_job.error_message or "Consistency comparison failed.") tool_job.status = "completed" tool_job.completed_at = datetime.utcnow() tool_job.output_summary = { **(tool_job.output_summary or {}), "current_step": "completed", "consistency_job_id": consistency_job.id, } db.add(tool_job) db.commit() except Exception as exc: if "db" not in locals() or db is None: db = SessionLocal() tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first() if tool_job: tool_job.status = "failed" tool_job.error_message = str(exc)[:2000] tool_job.completed_at = datetime.utcnow() tool_job.output_summary = { **(tool_job.output_summary or {}), "current_step": "failed", } db.add(tool_job) db.commit() finally: db.close() def run_consistency_job(job_id: int) -> None: db = SessionLocal() try: job = db.query(ConsistencyJob).filter(ConsistencyJob.id == job_id).first() if not job: return job.status = "processing" job.started_at = datetime.utcnow() db.add(job) db.commit() options = job.output_summary or {} code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == job.code_kb_id).first() if not code_kb: raise RuntimeError("Code knowledge base does not exist.") model_profile = ModelConfigService.require_active_config(db, job.user_id) ModelConfigService.touch_last_used(db, model_profile) adapter = CodeKnowledgeBaseAdapter( embedding_function=EmbeddingsFactory.create(model_profile=model_profile) ) adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path) llm = None if bool(options.get("use_llm", True)): llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile) comparator = ConsistencyComparator( adapter, llm=llm, use_llm=bool(options.get("use_llm", True)), ) query = ( db.query(SRSRequirement) .filter(SRSRequirement.extraction_id == job.srs_extraction_id) .order_by(SRSRequirement.sort_order) ) requirement_uids = options.get("requirement_uids") if requirement_uids: query = query.filter(SRSRequirement.requirement_uid.in_(requirement_uids)) requirements = query.all() job.total_requirements = len(requirements) db.add(job) db.commit() verdict_counter: Counter[str] = Counter() for requirement in requirements: result = comparator.compare_requirement( requirement, top_k=int(options.get("top_k", 8)), max_call_hops=int(options.get("max_call_hops", 2)), min_similarity=float(options.get("min_similarity", 0.55)), ) _store_result(db, job, result) verdict_counter[result.verdict] += 1 job.completed_requirements += 1 job.output_summary = {**options, "verdict_counts": dict(verdict_counter)} db.add(job) db.commit() job.status = "completed" job.completed_at = datetime.utcnow() job.output_summary = {**options, "verdict_counts": dict(verdict_counter)} db.add(job) db.commit() except Exception as exc: if "job" in locals() and job: job.status = "failed" job.error_message = str(exc) job.completed_at = datetime.utcnow() db.add(job) db.commit() finally: db.close()