Files
2026-04-13 11:34:23 +08:00

176 lines
4.9 KiB
Python

from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.tooling import SRSExtraction, ToolJob
from app.models.user import User
from app.schemas.tooling import (
SRSToolCreateJobResponse,
SRSToolJobStatusResponse,
SRSToolRequirementsSaveRequest,
SRSToolResultResponse,
ToolDefinitionResponse,
)
from app.services.srs_job_service import (
build_result_response,
ensure_upload_path,
replace_requirements,
run_srs_job,
)
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool
# Router that collects the tool endpoints declared in this module.
router = APIRouter()
# Register SRS tool when the router is imported.
# NOTE(review): the return value is discarded, so this call is made purely for
# its side effect — presumably registering the tool with ToolRegistry; confirm
# in app.tools.srs_reqs_qwen.
get_srs_tool()
# File suffixes accepted by the SRS upload endpoint; the endpoint lower-cases
# the uploaded file's suffix before testing membership in this set.
ALLOWED_EXTENSIONS = {".pdf", ".docx"}
@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return whatever ``ToolRegistry.list()`` reports.

    ``current_user`` is declared only so that the ``get_current_user``
    dependency is resolved for the request; its value is not used.
    """
    # The dependency has already run by the time we get here; drop the value.
    del current_user
    definitions = ToolRegistry.list()
    return definitions
@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Store an uploaded SRS document and schedule its extraction job.

    Steps:
      1. Validate the upload's suffix against ``ALLOWED_EXTENSIONS``.
      2. Insert a ``ToolJob`` row in "pending" state to obtain its id.
      3. Stream the upload to the path returned by ``ensure_upload_path``.
      4. Record the resolved path on the job and schedule ``run_srs_job``
         as a background task.

    Returns:
        A dict with ``job_id`` and ``status`` of the newly created job.

    Raises:
        HTTPException: 400 when the suffix is not allowed; 500 when the
            upload cannot be stored (the job row is then marked "failed").
    """
    # Keep only the final path component of the client-supplied name.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    # The row is committed before the file is written because the generated
    # job id is an input to ensure_upload_path.
    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    chunk_size = 1024 * 1024  # copy in 1 MiB pieces instead of buffering it all
    target_path = None
    try:
        # Resolved inside the try so that a failure here also marks the job
        # as failed instead of leaving it "pending" forever.
        target_path = ensure_upload_path(job.id, safe_name)
        with target_path.open("wb") as destination:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                destination.write(chunk)
    except Exception as exc:
        if target_path is not None:
            # Best-effort removal of a partially written file; the job is
            # already being reported as failed, so a cleanup error is ignored.
            try:
                target_path.unlink()
            except OSError:
                pass
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)
    return {
        "job_id": job.id,
        "status": job.status,
    }
@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the state of one of the current user's SRS jobs.

    The job is looked up by id *and* owner, so a job that belongs to another
    user is indistinguishable from a missing one.

    Returns:
        A dict with the job's id, tool name, status, error message,
        timestamps, and the id of its extraction row (``None`` if no
        extraction row exists).

    Raises:
        HTTPException: 404 when no matching job exists for this user.
    """
    owned_job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction_query = db.query(SRSExtraction).filter(
        SRSExtraction.job_id == job.id
    )
    extraction = extraction_query.first()

    extraction_id = None
    if extraction:
        extraction_id = extraction.id

    status_body = {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }
    return status_body
@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of a completed SRS job.

    Checks are applied in this order: the job must exist for the current
    user, its status must be "completed", and an extraction row must exist.

    Returns:
        The value produced by ``build_result_response(job, extraction)``.

    Raises:
        HTTPException: 404 when the job is not found for this user,
            409 when the job status is not "completed",
            404 when the job has no extraction row.
    """
    owned_job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    is_completed = job.status == "completed"
    if not is_completed:
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction_query = db.query(SRSExtraction).filter(
        SRSExtraction.job_id == job.id
    )
    extraction = extraction_query.first()
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    result = build_result_response(job, extraction)
    return result
@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace the requirements stored for one of the user's SRS jobs.

    Each item of ``payload.requirements`` is converted with ``.dict()`` and
    the resulting list is handed to ``replace_requirements`` together with
    the job's extraction row; the change is then committed.

    Returns:
        The value produced by ``build_result_response(job, extraction)``
        after the extraction row has been refreshed.

    Raises:
        HTTPException: 404 when the job is not found for this user, or when
            the job has no extraction row.
    """
    owned_job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction_query = db.query(SRSExtraction).filter(
        SRSExtraction.job_id == job.id
    )
    extraction = extraction_query.first()
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    requirement_rows = [item.dict() for item in payload.requirements]
    replace_requirements(db, extraction, requirement_rows)

    db.add(extraction)
    db.commit()
    db.refresh(extraction)

    result = build_result_response(job, extraction)
    return result