Files
rag_agent/rag-web-ui/backend/app/api/api_v1/tools.py

416 lines
12 KiB
Python

from pathlib import Path
from typing import Any, List
import shutil
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import KnowledgeBase
from app.models.tooling import SRSExtraction, TestingGeneration, ToolJob
from app.models.user import User
from app.schemas.tooling import (
SRSToolCreateJobResponse,
SRSToolHistoryItem,
SRSToolJobStatusResponse,
SRSToolRequirementsSaveRequest,
SRSToolResultResponse,
TestingGenerationCreateRequest,
TestingGenerationCreateResponse,
TestingGenerationHistoryItem,
TestingGenerationJobStatusResponse,
TestingGenerationResultResponse,
TestingGenerationSaveRequest,
ToolDefinitionResponse,
)
from app.services.srs_job_service import (
build_srs_upload_path,
build_result_response,
delete_srs_job,
ensure_upload_path,
list_srs_history,
replace_requirements,
run_srs_job,
)
from app.services.testing_generation_service import (
build_testing_generation_response,
create_testing_generation,
delete_testing_generation,
list_testing_history,
)
from app.services.testing_generation_job_service import run_testing_generation_job
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Register SRS tool when the router is imported.
get_srs_tool()
# File suffixes accepted by the SRS upload endpoint; the check there compares
# against the lowercased suffix of the uploaded file name.
ALLOWED_EXTENSIONS = {".pdf", ".docx"}
@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return whatever ``ToolRegistry.list()`` currently reports.

    ``current_user`` is declared so that the ``get_current_user`` dependency
    runs for this route; the resolved user is not used in the body.
    """
    # Explicitly discard the value: only the dependency's execution matters.
    _ = current_user
    registered_tools = ToolRegistry.list()
    return registered_tools
@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Accept an uploaded document, persist it, and queue an SRS job.

    Steps:
      1. Validate the upload's extension against ``ALLOWED_EXTENSIONS``.
      2. Insert a ``ToolJob`` row in ``pending`` state to obtain a job id.
      3. Store the uploaded bytes at the path returned by
         ``ensure_upload_path(job.id, safe_name)``.
      4. Record the resolved file path on the job and schedule
         ``run_srs_job`` as a background task.

    Returns:
        A dict with ``job_id`` and ``status`` of the newly created job.

    Raises:
        HTTPException: 400 when the file extension is not allowed; 500 when
            the upload location cannot be prepared or the file cannot be
            written (the job row is marked ``failed`` in that case).
    """
    # Keep only the final path component of the client-supplied file name.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    # The row is committed first because the upload path is derived from job.id.
    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        # Preparing the path is part of the guarded section: if it fails the
        # job must not be left behind in "pending" with no error recorded.
        target_path = ensure_upload_path(job.id, safe_name)
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        # Chain the cause so the original traceback is kept for server logs.
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)
    return {
        "job_id": job.id,
        "status": job.status,
    }
@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the state of one of the current user's jobs.

    The job is looked up by id and owner. The response also carries the id of
    the ``SRSExtraction`` row linked to the job, or ``None`` when no such row
    exists.

    Raises:
        HTTPException: 404 when no job with this id belongs to the user.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    extraction_id = None
    if extraction:
        extraction_id = extraction.id

    return {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }
@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of a completed job owned by the user.

    Raises:
        HTTPException: 404 when the job is not found for this user, 409 when
            the job's status is not ``"completed"``, and 404 when no
            ``SRSExtraction`` row is linked to the job.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Results are only served once the job has reached the completed state.
    if not job.status == "completed":
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    if extraction is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    return build_result_response(job, extraction)
@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace the requirements stored for a job's extraction.

    Each entry of ``payload.requirements`` is converted with ``.dict()`` and
    the resulting list is handed to ``replace_requirements``. The extraction
    is then committed, refreshed, and returned through
    ``build_result_response``.

    Raises:
        HTTPException: 404 when the job is not found for this user, or when
            no ``SRSExtraction`` row is linked to the job.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    if extraction is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    requirement_dicts = [entry.dict() for entry in payload.requirements]
    replace_requirements(db, extraction, requirement_dicts)

    db.add(extraction)
    db.commit()
    # Reload the extraction so the response is built from the committed state.
    db.refresh(extraction)
    return build_result_response(job, extraction)
@router.get("/srs/history", response_model=List[SRSToolHistoryItem])
async def get_srs_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the result of ``list_srs_history`` for the current user's id."""
    user_id = current_user.id
    history_items = list_srs_history(db, user_id)
    return history_items
@router.delete("/srs/jobs/{job_id}")
async def delete_srs_job_api(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Delete one of the current user's jobs and its upload location.

    The job is removed through ``delete_srs_job``. Afterwards the path
    returned by ``build_srs_upload_path(job_id)`` is removed recursively if it
    exists; errors during that removal are ignored.

    Raises:
        HTTPException: 404 when no job with this id belongs to the user.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Resolve the upload location before the job is deleted.
    job_upload_dir = build_srs_upload_path(job_id)
    delete_srs_job(db, job)

    # Best-effort cleanup: removal errors are deliberately ignored.
    if job_upload_dir.exists():
        shutil.rmtree(job_upload_dir, ignore_errors=True)

    return {"message": "删除成功"}
@router.post("/testing/generations", response_model=TestingGenerationResultResponse)
async def save_testing_generation(
    payload: TestingGenerationSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Validate the referenced records, then store a testing generation.

    When ``payload.source_job_id`` or ``payload.knowledge_base_id`` is set,
    the referenced row must exist and belong to the current user. The
    generation itself is created by ``create_testing_generation``.

    Raises:
        HTTPException: 404 when a referenced source job or knowledge base is
            not found for this user.
    """
    source_job_id = payload.source_job_id
    if source_job_id is not None:
        owned_source_job = (
            db.query(ToolJob)
            .filter(
                ToolJob.id == source_job_id,
                ToolJob.user_id == current_user.id,
            )
            .first()
        )
        if owned_source_job is None:
            raise HTTPException(status_code=404, detail="来源文件不存在")

    knowledge_base_id = payload.knowledge_base_id
    if knowledge_base_id is not None:
        owned_knowledge_base = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id == knowledge_base_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )
        if owned_knowledge_base is None:
            raise HTTPException(status_code=404, detail="知识库不存在")

    return create_testing_generation(db, current_user.id, payload)
@router.post("/testing/jobs", response_model=TestingGenerationCreateResponse)
async def create_testing_generation_job(
    background_tasks: BackgroundTasks,
    payload: TestingGenerationCreateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Create a pending test-generation job and schedule it in the background.

    When ``payload.source_job_id`` or ``payload.knowledge_base_id`` is set,
    the referenced row must exist and belong to the current user. The new
    ``ToolJob`` starts with an ``output_summary`` holding the source document
    name, ``current_step`` 0 and ``total_steps`` equal to the number of
    requirements in the payload.

    Returns:
        A dict with ``job_id`` and ``status`` of the newly created job.

    Raises:
        HTTPException: 404 when a referenced source job or knowledge base is
            not found for this user.
    """
    source_job_id = payload.source_job_id
    if source_job_id is not None:
        owned_source_job = (
            db.query(ToolJob)
            .filter(
                ToolJob.id == source_job_id,
                ToolJob.user_id == current_user.id,
            )
            .first()
        )
        if owned_source_job is None:
            raise HTTPException(status_code=404, detail="来源文件不存在")

    knowledge_base_id = payload.knowledge_base_id
    if knowledge_base_id is not None:
        owned_knowledge_base = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id == knowledge_base_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )
        if owned_knowledge_base is None:
            raise HTTPException(status_code=404, detail="知识库不存在")

    # Initial progress snapshot stored on the job row.
    initial_summary = {
        "source_document_name": payload.source_document_name,
        "current_step": 0,
        "total_steps": len(payload.requirements),
    }
    job = ToolJob(
        user_id=current_user.id,
        tool_name="testing.case_generator",
        status="pending",
        input_file_name=payload.source_document_name,
        input_file_path="",
        output_summary=initial_summary,
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    # The background task receives the job id and a plain-dict copy of the payload.
    background_tasks.add_task(
        run_testing_generation_job,
        job.id,
        payload.dict(),
    )

    return {
        "job_id": job.id,
        "status": job.status,
    }
@router.get("/testing/jobs/{job_id}", response_model=TestingGenerationJobStatusResponse)
async def get_testing_generation_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the state and progress fields of one of the user's jobs.

    Progress fields are read from ``job.output_summary``; a field that is
    absent there (or a missing/empty summary) is reported as ``None``.

    Raises:
        HTTPException: 404 when no job with this id belongs to the user.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    summary = job.output_summary or {}

    status_payload = {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }
    # Copy the progress fields out of the summary, defaulting each to None.
    progress_fields = (
        "source_document_name",
        "current_step",
        "total_steps",
        "current_requirement_id",
    )
    for field_name in progress_fields:
        status_payload[field_name] = summary.get(field_name)

    return status_payload
@router.get("/testing/history", response_model=List[TestingGenerationHistoryItem])
async def get_testing_history(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the result of ``list_testing_history`` for the current user's id."""
    user_id = current_user.id
    history_items = list_testing_history(db, user_id)
    return history_items
@router.get("/testing/jobs/{job_id}/result", response_model=TestingGenerationResultResponse)
async def get_testing_generation_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the stored testing generation linked to one of the user's jobs.

    Raises:
        HTTPException: 404 when the job is not found for this user, or when
            no ``TestingGeneration`` row is linked to the job.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    linked_generations = db.query(TestingGeneration).filter(
        TestingGeneration.job_id == job.id
    )
    generation = linked_generations.first()
    if generation is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    return build_testing_generation_response(job, generation)
@router.delete("/testing/jobs/{job_id}")
async def delete_testing_generation_api(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Delete one of the user's jobs via ``delete_testing_generation``.

    Raises:
        HTTPException: 404 when the job is not found for this user, or when
            no ``TestingGeneration`` row is linked to the job.
    """
    owned_jobs = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = owned_jobs.first()
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # NOTE(review): a job with no linked generation row is answered with 404
    # and is not deleted here — confirm that this is the intended behavior.
    linked_generations = db.query(TestingGeneration).filter(
        TestingGeneration.job_id == job.id
    )
    generation = linked_generations.first()
    if generation is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    delete_testing_generation(db, job)
    return {"message": "删除成功"}