from pathlib import Path
from typing import Any, List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session

from app.core.security import get_current_user
from app.db.session import get_db
from app.models.tooling import SRSExtraction, ToolJob
from app.models.user import User
from app.schemas.tooling import (
    SRSToolCreateJobResponse,
    SRSToolJobStatusResponse,
    SRSToolRequirementsSaveRequest,
    SRSToolResultResponse,
    ToolDefinitionResponse,
)
from app.services.srs_job_service import (
    build_result_response,
    ensure_upload_path,
    replace_requirements,
    run_srs_job,
)
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool

router = APIRouter()

# Register the SRS tool when this router module is imported.
get_srs_tool()

# File extensions accepted by create_srs_job (compared after lower-casing).
ALLOWED_EXTENSIONS = {".pdf", ".docx"}


def _get_owned_job(db: Session, job_id: int, user: User) -> ToolJob:
    """Return job ``job_id`` if it belongs to ``user``; otherwise raise HTTP 404.

    Filtering on ``user_id`` in the query means a job owned by someone else is
    reported exactly like a nonexistent one, without revealing that it exists.
    """
    job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == user.id)
        .first()
    )
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")
    return job


def _find_extraction(db: Session, job: ToolJob) -> Optional[SRSExtraction]:
    """Return the SRSExtraction row linked to ``job``, or ``None`` if there is none."""
    return db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()


@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """List every tool definition held by ``ToolRegistry``.

    ``current_user`` is not used in the body; the dependency is declared so
    that ``get_current_user`` runs for this route (presumably enforcing
    authentication).
    """
    _ = current_user
    return ToolRegistry.list()


@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Store an uploaded SRS document and queue a requirement-extraction job.

    A ``pending`` ToolJob row is committed first so its generated id can be
    used to build the upload path. The file is then written to disk, the
    job's ``input_file_path`` is recorded, and ``run_srs_job`` is scheduled as
    a background task.

    Returns:
        ``{"job_id": ..., "status": ...}`` for the newly created job.

    Raises:
        HTTPException(400): the upload's extension is not ``.pdf``/``.docx``.
        HTTPException(500): preparing the upload path or writing the file
            failed; the job is marked ``failed`` with the error recorded.
    """
    # Drop any directory components from the client-supplied filename.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",  # filled in once the file has been stored
    )
    db.add(job)
    db.commit()
    db.refresh(job)  # reload so job.id is available for the upload path

    try:
        # ensure_upload_path is inside the try on purpose: if it fails, the
        # job must be marked failed instead of staying "pending" forever,
        # since the background task is only scheduled on success below.
        target_path = ensure_upload_path(job.id, safe_name)
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)
    return {
        "job_id": job.id,
        "status": job.status,
    }


@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the status of one of the caller's jobs.

    ``extraction_id`` is ``None`` while no SRSExtraction row is linked to the
    job.

    Raises:
        HTTPException(404): the job does not exist or is not the caller's.
    """
    job = _get_owned_job(db, job_id, current_user)
    extraction = _find_extraction(db, job)
    return {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction.id if extraction else None,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }


@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of one of the caller's completed jobs.

    Raises:
        HTTPException(404): the job is missing / not the caller's, or it has
            no linked extraction.
        HTTPException(409): the job's status is not ``completed``.
    """
    job = _get_owned_job(db, job_id, current_user)
    if job.status != "completed":
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction = _find_extraction(db, job)
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")
    return build_result_response(job, extraction)


@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace the requirements stored on a job's extraction with ``payload``.

    Unlike ``get_srs_job_result``, this does not check the job's status; only
    the existence of a linked extraction is required.

    Raises:
        HTTPException(404): the job is missing / not the caller's, or it has
            no linked extraction.
    """
    job = _get_owned_job(db, job_id, current_user)
    extraction = _find_extraction(db, job)
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    # NOTE(review): `.dict()` is the Pydantic v1 API (deprecated in v2 in
    # favour of `.model_dump()`); confirm which version the project pins.
    replace_requirements(db, extraction, [item.dict() for item in payload.requirements])
    db.add(extraction)
    db.commit()
    db.refresh(extraction)
    return build_result_response(job, extraction)