"""HTTP endpoints for the SRS requirement-extraction and test-case-generation tools."""

import shutil
from pathlib import Path
from typing import Any, List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session

from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import KnowledgeBase
from app.models.tooling import SRSExtraction, TestingGeneration, ToolJob
from app.models.user import User
from app.schemas.tooling import (
    SRSToolCreateJobResponse,
    SRSToolHistoryItem,
    SRSToolJobStatusResponse,
    SRSToolRequirementsSaveRequest,
    SRSToolResultResponse,
    TestingGenerationCreateRequest,
    TestingGenerationCreateResponse,
    TestingGenerationHistoryItem,
    TestingGenerationJobStatusResponse,
    TestingGenerationResultResponse,
    TestingGenerationSaveRequest,
    ToolDefinitionResponse,
)
from app.services.srs_job_service import (
    build_srs_upload_path,
    build_result_response,
    delete_srs_job,
    ensure_upload_path,
    list_srs_history,
    replace_requirements,
    run_srs_job,
)
from app.services.testing_generation_service import (
    build_testing_generation_response,
    create_testing_generation,
    delete_testing_generation,
    list_testing_history,
)
from app.services.testing_generation_job_service import run_testing_generation_job
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool

# NOTE(review): the handlers below are `async def` but use the synchronous
# SQLAlchemy `Session`, so every DB call runs on the event-loop thread.
# Consider plain `def` handlers (FastAPI runs those in a threadpool) if this
# becomes a latency problem.
router = APIRouter()

# Register SRS tool when the router is imported.
get_srs_tool()

ALLOWED_EXTENSIONS = {".pdf", ".docx"}

# Values written to ToolJob.tool_name by the job-creating endpoints below.
SRS_TOOL_NAME = "srs.requirement_extractor"
TESTING_TOOL_NAME = "testing.case_generator"


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _get_owned_job(
    db: Session,
    job_id: int,
    user_id: Any,
    tool_name: Optional[str] = None,
) -> ToolJob:
    """Return the ToolJob ``job_id`` owned by ``user_id``, or raise 404.

    If ``tool_name`` is given, the job must also belong to that tool. This
    keeps an endpoint for one tool from reading or deleting another tool's
    jobs.

    Raises:
        HTTPException: 404 when no matching job exists.
    """
    query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == user_id,
    )
    if tool_name is not None:
        query = query.filter(ToolJob.tool_name == tool_name)
    job = query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")
    return job


def _get_srs_extraction(db: Session, job_id: int) -> Optional[SRSExtraction]:
    """Return the SRSExtraction row linked to ``job_id``, or None if absent."""
    return db.query(SRSExtraction).filter(SRSExtraction.job_id == job_id).first()


def _get_testing_generation(db: Session, job_id: int) -> Optional[TestingGeneration]:
    """Return the TestingGeneration row linked to ``job_id``, or None if absent."""
    return (
        db.query(TestingGeneration)
        .filter(TestingGeneration.job_id == job_id)
        .first()
    )


def _validate_testing_sources(
    db: Session,
    user_id: Any,
    source_job_id: Optional[int],
    knowledge_base_id: Optional[int],
) -> None:
    """Check that the optional source job and knowledge base belong to the user.

    A reference that is None is not checked.

    Raises:
        HTTPException: 404 if a given source job or knowledge base is not
            owned by ``user_id``.
    """
    if source_job_id is not None:
        source_job = (
            db.query(ToolJob)
            .filter(
                ToolJob.id == source_job_id,
                ToolJob.user_id == user_id,
            )
            .first()
        )
        if not source_job:
            raise HTTPException(status_code=404, detail="来源文件不存在")

    if knowledge_base_id is not None:
        knowledge_base = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id == knowledge_base_id,
                KnowledgeBase.user_id == user_id,
            )
            .first()
        )
        if not knowledge_base:
            raise HTTPException(status_code=404, detail="知识库不存在")


# ---------------------------------------------------------------------------
# Tool registry
# ---------------------------------------------------------------------------


@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """List the tools registered in ToolRegistry (requires an authenticated user)."""
    # The dependency is only here to enforce authentication.
    _ = current_user
    return ToolRegistry.list()


# ---------------------------------------------------------------------------
# SRS requirement extraction
# ---------------------------------------------------------------------------


@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Store an uploaded .pdf/.docx, create a pending SRS job, and schedule it.

    Returns:
        ``{"job_id": ..., "status": ...}`` for the newly created job.

    Raises:
        HTTPException: 400 for an unsupported file extension. 500 if the
            upload cannot be saved; in that case the job is marked "failed".
    """
    # Keep only the final path component so a client-supplied name cannot
    # point outside the upload directory.
    safe_name = Path(file.filename or "").name
    if Path(safe_name).suffix.lower() not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    job = ToolJob(
        user_id=current_user.id,
        tool_name=SRS_TOOL_NAME,
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)  # populate job.id, which the upload path depends on

    try:
        # Path preparation sits inside the try so that a failure here also
        # marks the job failed, instead of leaving it "pending" forever.
        target_path = ensure_upload_path(job.id, safe_name)
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        # Best-effort removal of anything partially written for this job.
        shutil.rmtree(build_srs_upload_path(job.id), ignore_errors=True)
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)
    return {
        "job_id": job.id,
        "status": job.status,
    }


@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the status of one of the user's SRS jobs plus its extraction id, if any."""
    job = _get_owned_job(db, job_id, current_user.id, SRS_TOOL_NAME)
    extraction = _get_srs_extraction(db, job.id)
    return {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction.id if extraction else None,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }


@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of a completed SRS job.

    Raises:
        HTTPException: 404 if the job or its extraction is missing. 409 if
            the job status is not "completed".
    """
    job = _get_owned_job(db, job_id, current_user.id, SRS_TOOL_NAME)
    if job.status != "completed":
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction = _get_srs_extraction(db, job.id)
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")
    return build_result_response(job, extraction)


@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace an SRS job's extracted requirements with the user-edited list.

    Raises:
        HTTPException: 404 if the job or its extraction is missing.
    """
    job = _get_owned_job(db, job_id, current_user.id, SRS_TOOL_NAME)
    extraction = _get_srs_extraction(db, job.id)
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    replace_requirements(db, extraction, [item.dict() for item in payload.requirements])
    db.add(extraction)
    db.commit()
    db.refresh(extraction)
    return build_result_response(job, extraction)


@router.get("/srs/history", response_model=List[SRSToolHistoryItem])
async def get_srs_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the current user's SRS job history (via list_srs_history)."""
    return list_srs_history(db, current_user.id)


@router.delete("/srs/jobs/{job_id}")
async def delete_srs_job_api(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Delete one of the user's SRS jobs and its upload directory.

    Only SRS jobs qualify. A testing-generation job id returns 404 rather
    than being passed to delete_srs_job.
    """
    job = _get_owned_job(db, job_id, current_user.id, SRS_TOOL_NAME)

    upload_path = build_srs_upload_path(job_id)
    delete_srs_job(db, job)
    # File cleanup is best-effort; the DB deletion has already happened.
    if upload_path.exists():
        shutil.rmtree(upload_path, ignore_errors=True)
    return {"message": "删除成功"}


# ---------------------------------------------------------------------------
# Test-case generation
#
# Lookups here are not restricted by tool_name: rows created through
# create_testing_generation are built outside this module, and their
# tool_name is not visible here.
# ---------------------------------------------------------------------------


@router.post("/testing/generations", response_model=TestingGenerationResultResponse)
async def save_testing_generation(
    payload: TestingGenerationSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Save a test-case generation synchronously (via create_testing_generation).

    Raises:
        HTTPException: 404 if the referenced source job or knowledge base is
            not owned by the user.
    """
    _validate_testing_sources(
        db,
        current_user.id,
        payload.source_job_id,
        payload.knowledge_base_id,
    )
    return create_testing_generation(db, current_user.id, payload)


@router.post("/testing/jobs", response_model=TestingGenerationCreateResponse)
async def create_testing_generation_job(
    background_tasks: BackgroundTasks,
    payload: TestingGenerationCreateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Create a pending test-case generation job and schedule it in the background.

    The job's output_summary starts with progress counters (current_step=0,
    total_steps=number of requirements). The status endpoint reads these
    counters back.

    Raises:
        HTTPException: 404 if the referenced source job or knowledge base is
            not owned by the user.
    """
    _validate_testing_sources(
        db,
        current_user.id,
        payload.source_job_id,
        payload.knowledge_base_id,
    )

    job = ToolJob(
        user_id=current_user.id,
        tool_name=TESTING_TOOL_NAME,
        status="pending",
        input_file_name=payload.source_document_name,
        input_file_path="",
        output_summary={
            "source_document_name": payload.source_document_name,
            "current_step": 0,
            "total_steps": len(payload.requirements),
        },
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    background_tasks.add_task(
        run_testing_generation_job,
        job.id,
        payload.dict(),
    )
    return {
        "job_id": job.id,
        "status": job.status,
    }


@router.get("/testing/jobs/{job_id}", response_model=TestingGenerationJobStatusResponse)
async def get_testing_generation_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return a testing job's status and the progress fields stored in output_summary."""
    job = _get_owned_job(db, job_id, current_user.id)
    summary = job.output_summary or {}
    return {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "source_document_name": summary.get("source_document_name"),
        "current_step": summary.get("current_step"),
        "total_steps": summary.get("total_steps"),
        "current_requirement_id": summary.get("current_requirement_id"),
    }


@router.get("/testing/history", response_model=List[TestingGenerationHistoryItem])
async def get_testing_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the current user's test-case generation history (via list_testing_history)."""
    return list_testing_history(db, current_user.id)


@router.get("/testing/jobs/{job_id}/result", response_model=TestingGenerationResultResponse)
async def get_testing_generation_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the generated test cases for one of the user's jobs.

    Raises:
        HTTPException: 404 if the job or its TestingGeneration row is missing.
    """
    job = _get_owned_job(db, job_id, current_user.id)
    generation = _get_testing_generation(db, job.id)
    if not generation:
        raise HTTPException(status_code=404, detail="任务结果不存在")
    return build_testing_generation_response(job, generation)


@router.delete("/testing/jobs/{job_id}")
async def delete_testing_generation_api(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Delete a testing job via delete_testing_generation.

    Raises:
        HTTPException: 404 if the job or its TestingGeneration row is missing.
    """
    job = _get_owned_job(db, job_id, current_user.id)
    # NOTE(review): jobs that never produced a TestingGeneration row cannot be
    # deleted through this endpoint. Confirm whether delete_testing_generation
    # handles a missing generation before relaxing this check.
    generation = _get_testing_generation(db, job.id)
    if not generation:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    delete_testing_generation(db, job)
    return {"message": "删除成功"}