176 lines
4.9 KiB
Python
176 lines
4.9 KiB
Python
|
|
from pathlib import Path
|
||
|
|
from typing import Any, List
|
||
|
|
|
||
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.core.security import get_current_user
|
||
|
|
from app.db.session import get_db
|
||
|
|
from app.models.tooling import SRSExtraction, ToolJob
|
||
|
|
from app.models.user import User
|
||
|
|
from app.schemas.tooling import (
|
||
|
|
SRSToolCreateJobResponse,
|
||
|
|
SRSToolJobStatusResponse,
|
||
|
|
SRSToolRequirementsSaveRequest,
|
||
|
|
SRSToolResultResponse,
|
||
|
|
ToolDefinitionResponse,
|
||
|
|
)
|
||
|
|
from app.services.srs_job_service import (
|
||
|
|
build_result_response,
|
||
|
|
ensure_upload_path,
|
||
|
|
replace_requirements,
|
||
|
|
run_srs_job,
|
||
|
|
)
|
||
|
|
from app.tools.registry import ToolRegistry
|
||
|
|
from app.tools.srs_reqs_qwen import get_srs_tool
|
||
|
|
|
||
|
|
router = APIRouter()

# Side effect on import: calling get_srs_tool() registers the SRS tool, so it
# is available as soon as this router module is loaded.
get_srs_tool()

# Lower-cased file suffixes (leading dot included) accepted for SRS uploads.
ALLOWED_EXTENSIONS = {".docx", ".pdf"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the tool definitions currently held by ``ToolRegistry``.

    ``current_user`` is never read here. The dependency is declared so that
    ``get_current_user`` runs for every request (presumably rejecting
    unauthenticated callers — confirm in ``app.core.security``).
    """
    del current_user  # dependency is required only for its side effect
    return ToolRegistry.list()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Accept an SRS document upload and schedule background extraction.

    Steps:
    1. Validate the upload's file extension against ``ALLOWED_EXTENSIONS``.
    2. Create and commit a ``ToolJob`` row owned by the current user.
    3. Write the uploaded bytes to the path returned by ``ensure_upload_path``.
    4. Schedule ``run_srs_job`` for that job as a background task.

    Returns:
        A mapping with the new ``job_id`` and the job's ``status``.

    Raises:
        HTTPException: 400 if the file extension is not ``.pdf``/``.docx``;
            500 if the upload could not be stored. In the 500 case the job
            is first marked ``"failed"`` with the error recorded.
    """
    # Keep only the final path component, dropping any directory parts the
    # client put into the filename.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    # The job is committed before the file is saved because its id is passed
    # to ensure_upload_path to build the storage location.
    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        # Path resolution is guarded as well: if it fails, the committed job
        # must be marked failed instead of being left "pending" forever.
        target_path = ensure_upload_path(job.id, safe_name)
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)

    return {
        "job_id": job.id,
        "status": job.status,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the progress of one of the caller's SRS jobs.

    Only jobs whose ``user_id`` matches the current user are visible. Other
    users' jobs produce the same 404 as missing ones.

    Returns:
        The job's id, tool name, status, error message and start/completion
        timestamps, plus the id of its ``SRSExtraction`` row (``None`` when
        no extraction row exists yet).

    Raises:
        HTTPException: 404 if no matching job belongs to the current user.
    """
    owned_job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if not owned_job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction_row = (
        db.query(SRSExtraction)
        .filter(SRSExtraction.job_id == owned_job.id)
        .first()
    )
    extraction_id = None if not extraction_row else extraction_row.id

    return {
        "job_id": owned_job.id,
        "tool_name": owned_job.tool_name,
        "status": owned_job.status,
        "error_message": owned_job.error_message,
        "extraction_id": extraction_id,
        "started_at": owned_job.started_at,
        "completed_at": owned_job.completed_at,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of one of the caller's completed SRS jobs.

    Raises:
        HTTPException: 404 if the job does not exist or belongs to another
            user; 409 if the job's status is not ``"completed"``; 404 if the
            job has no ``SRSExtraction`` row.
    """
    owned_job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if not owned_job:
        raise HTTPException(status_code=404, detail="任务不存在")
    # Refuse to serve results until the job reports "completed".
    if owned_job.status != "completed":
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction_row = (
        db.query(SRSExtraction)
        .filter(SRSExtraction.job_id == owned_job.id)
        .first()
    )
    if not extraction_row:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    return build_result_response(owned_job, extraction_row)
|
||
|
|
|
||
|
|
|
||
|
|
@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Store a user-edited requirement list for one of the caller's SRS jobs.

    The payload's requirement items are converted to plain dicts and handed
    to ``replace_requirements`` (presumably overwriting the extraction's
    stored requirements — confirm in ``srs_job_service``). The extraction is
    then committed, refreshed and returned in result form.

    Raises:
        HTTPException: 404 if the job does not exist or belongs to another
            user, or if the job has no ``SRSExtraction`` row.
    """
    owned_job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if not owned_job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction_row = (
        db.query(SRSExtraction)
        .filter(SRSExtraction.job_id == owned_job.id)
        .first()
    )
    if not extraction_row:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    edited_rows = [requirement.dict() for requirement in payload.requirements]
    replace_requirements(db, extraction_row, edited_rows)

    db.add(extraction_row)
    db.commit()
    db.refresh(extraction_row)

    return build_result_response(owned_job, extraction_row)
|