Files

188 lines
6.7 KiB
Python
Raw Permalink Normal View History

2026-04-13 11:34:23 +08:00
from __future__ import annotations

import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List

from sqlalchemy.orm import Session

from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool
def run_srs_job(job_id: int) -> None:
    """Run the SRS extraction tool for one ``ToolJob`` and persist its results.

    Flow:

    1. Load the job; return without doing anything if no job has *job_id*.
    2. Mark it ``"processing"`` (clearing any previous error message) and commit
       immediately, so that status is persisted before the tool runs.
    3. Call ``get_srs_tool().run(job.input_file_path)`` and store one
       ``SRSExtraction`` plus one ``SRSRequirement`` per entry of the payload's
       ``"requirements"`` list (a missing or null list counts as empty).
    4. Mark the job ``"completed"`` with a small ``output_summary`` and commit.

    Any exception raised during steps 1-4 is logged with its traceback, the
    session is rolled back, and the job is marked ``"failed"`` via
    ``_mark_job_failed``. The exception is not re-raised; an error raised while
    recording the failure itself does propagate. The session is always closed.

    Args:
        job_id: Primary key of the ``ToolJob`` to process.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return
        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        db.commit()

        payload = get_srs_tool().run(job.input_file_path)
        # Treat a missing *or* null "requirements" entry as "nothing extracted".
        # Without `or []`, a null value would fail the job with len(None).
        items = payload.get("requirements") or []

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=payload["document_name"],
            document_title=payload.get("document_title") or payload["document_name"],
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(items),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        db.add(extraction)
        # Flush (emit the INSERT) so extraction.id is populated before the
        # child rows below reference it.
        db.flush()

        for item in items:
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=item["id"],
                    title=item.get("title") or item["id"],
                    description=item.get("description") or "",
                    priority=item.get("priority") or "",
                    # Placeholder text meaning "acceptance criteria to be added".
                    acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
                    # Default provenance label meaning "document parsing".
                    source_field=item.get("source_field") or "文档解析",
                    section_number=item.get("section_number"),
                    section_title=item.get("section_title"),
                    requirement_type=item.get("requirement_type"),
                    sort_order=int(item.get("sort_order") or 0),
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        db.commit()
    except Exception as exc:
        # The exception is swallowed here and only its message reaches the job
        # row, so log the full traceback first.
        logging.getLogger(__name__).exception("SRS job %s failed", job_id)
        db.rollback()
        # Some exceptions stringify to "" (e.g. a bare RuntimeError()); fall
        # back to the class name so a failed job never has an empty message.
        _mark_job_failed(job_id=job_id, error_message=str(exc) or type(exc).__name__)
    finally:
        db.close()
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a job failure using a dedicated session that is always closed.

    If a ``ToolJob`` with *job_id* exists, set its status to ``"failed"``, set
    ``completed_at`` to the current (naive) UTC time, store at most the first
    2000 characters of *error_message*, and commit. If no such job exists,
    nothing is written.

    Args:
        job_id: Primary key of the job to mark as failed.
        error_message: Human-readable failure reason; truncated to 2000 chars.
    """
    session = SessionLocal()
    try:
        failed_job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if failed_job:
            failed_job.status = "failed"
            failed_job.completed_at = datetime.utcnow()
            failed_job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
def _parse_generated_at(value: Any) -> datetime:
    """Coerce a tool-supplied ``generated_at`` value into a naive UTC datetime.

    Accepted inputs:

    * a ``datetime`` instance, which is used as-is;
    * an ISO-8601 string, with surrounding whitespace ignored and a trailing
      ``Z`` accepted on every Python version.

    Timezone-aware results are converted to UTC and made naive, so they match
    the naive ``datetime.utcnow()`` values this module stores elsewhere.
    Anything that cannot be interpreted (``None``, other types, malformed
    strings) falls back to the current UTC time instead of raising.

    Args:
        value: The raw ``generated_at`` value from the tool payload.

    Returns:
        A naive datetime representing UTC.
    """
    parsed: datetime | None = None
    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, str):
        text = value.strip()
        # datetime.fromisoformat only understands a "Z" suffix from Python 3.11.
        if text.endswith(("Z", "z")):
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            parsed = None

    if parsed is None:
        return datetime.utcnow()
    if parsed.tzinfo is not None:
        # Normalise to naive UTC so aware and naive values are never mixed.
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed
def ensure_upload_path(job_id: int, file_name: str, *, base_dir: str | Path = "uploads") -> Path:
    """Create the per-job upload directory and return where *file_name* should be stored.

    *file_name* is treated as untrusted (it is the uploaded file's name). Only
    its final component is kept, with both ``/`` and ``\\`` treated as
    separators. As a result, names such as ``../../x``, ``/etc/x`` or
    ``C:\\tmp\\x.docx`` cannot place the file outside
    ``<base_dir>/srs_jobs/<job_id>/``.

    Args:
        job_id: Job id, used as the directory name.
        file_name: Client-supplied upload file name.
        base_dir: Root upload directory. The default ``"uploads"`` is relative
            to the current working directory.

    Returns:
        ``<base_dir>/srs_jobs/<job_id>/<final component of file_name>``. The
        containing directory exists when this returns.

    Raises:
        ValueError: If *file_name* has no usable final component (empty,
            ``.``, ``..``, or a bare separator).
    """
    # Normalise Windows separators so Path(...).name strips every directory part.
    safe_name = Path((file_name or "").replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"invalid upload file name: {file_name!r}")
    target_dir = Path(base_dir) / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / safe_name
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Serialise a job's extraction into the camelCase dict returned to clients.

    Args:
        job: The job whose ``id`` is echoed back as ``jobId``.
        extraction: The stored extraction. Its ``requirements`` are emitted in
            the order that attribute yields them.

    Returns:
        A dict with these keys, in order:

        * ``jobId``
        * ``documentName``
        * ``generatedAt`` (``generated_at.isoformat()``)
        * ``statistics`` (``{}`` when the stored value is falsy)
        * ``requirements`` (one dict per requirement; ``acceptanceCriteria``
          becomes ``[]`` when the stored value is falsy)
    """
    serialized_requirements: List[Dict[str, Any]] = [
        {
            "id": row.requirement_uid,
            "title": row.title,
            "description": row.description,
            "priority": row.priority,
            "acceptanceCriteria": row.acceptance_criteria or [],
            "sourceField": row.source_field,
            "sectionNumber": row.section_number,
            "sectionTitle": row.section_title,
            "requirementType": row.requirement_type,
            "sortOrder": row.sort_order,
        }
        for row in extraction.requirements
    ]
    return {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": serialized_requirements,
    }
def _coerce_sort_order(value: Any, default: int) -> int:
    """Return *value* as an int sort key, or *default* when it is ``None`` or ``""``."""
    if value is None or value == "":
        return default
    return int(value)


def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Reconcile the stored requirements of *extraction* with the edited list *updates*.

    Each item in *updates* is matched to an existing ``SRSRequirement`` row by
    its ``"id"`` key (compared with ``requirement_uid``):

    * An id with no stored row creates a new row. Missing or empty text fields
      get placeholder defaults.
    * An id with a stored row is updated field by field. A key absent from the
      item leaves that column unchanged.
    * A stored row whose id does not appear in *updates* is deleted.

    In both the create and update paths, ``sortOrder`` falls back to the item's
    position in *updates* when it is missing, ``None`` or ``""``. An explicit
    ``0`` is kept.

    Afterwards ``total_requirements``, ``statistics`` and ``raw_output`` on
    *extraction* are rebuilt from *updates*. Changes are only staged on *db*;
    this function neither flushes nor commits.

    NOTE(review): an id that is not yet stored but appears more than once in
    *updates* inserts one row per occurrence. Confirm that callers guarantee
    unique ids.

    Args:
        db: Session used to load, add and delete requirement rows.
        extraction: The extraction whose requirements are being replaced.
        updates: Client-edited requirement dicts (camelCase keys, ``"id"`` required).

    Raises:
        KeyError: If an item has no ``"id"`` key.
        ValueError / TypeError: If a ``sortOrder`` value cannot be converted by ``int()``.
    """
    existing = {
        req.requirement_uid: req
        for req in db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    }
    seen_ids = set()
    for index, item in enumerate(updates):
        uid = item["id"]
        seen_ids.add(uid)
        # Shared by both paths. Previously the create path dropped an explicit
        # 0, and the update path crashed on null/"".
        sort_order = _coerce_sort_order(item.get("sortOrder"), index)
        req = existing.get(uid)
        if req is None:
            description = item.get("description")
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=item.get("title") or uid,
                    description=description if description is not None else "",
                    priority=item.get("priority") or "",
                    # Placeholder text meaning "acceptance criteria to be added".
                    acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
                    # Default provenance label meaning "document parsing".
                    source_field=item.get("sourceField") or "文档解析",
                    section_number=item.get("sectionNumber"),
                    section_title=item.get("sectionTitle"),
                    requirement_type=item.get("requirementType"),
                    sort_order=sort_order,
                )
            )
            continue
        # Partial update: a missing key keeps the stored value.
        req.title = item.get("title", req.title)
        req.description = item.get("description", req.description)
        req.priority = item.get("priority", req.priority)
        req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
        req.source_field = item.get("sourceField", req.source_field)
        req.section_number = item.get("sectionNumber", req.section_number)
        req.section_title = item.get("sectionTitle", req.section_title)
        req.requirement_type = item.get("requirementType", req.requirement_type)
        req.sort_order = sort_order

    for uid, req in existing.items():
        if uid not in seen_ids:
            db.delete(req)

    extraction.total_requirements = len(updates)
    extraction.statistics = {
        "total": len(updates),
        "by_type": _count_requirement_types(updates),
    }
    extraction.raw_output = {
        "document_name": extraction.document_name,
        "generated_at": extraction.generated_at.isoformat(),
        "requirements": updates,
    }
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
    """Tally *items* by their ``"requirementType"`` value.

    Items whose type is missing or falsy are counted under ``"functional"``.
    Keys appear in the order each type is first encountered.

    Args:
        items: Requirement dicts using camelCase keys.

    Returns:
        A plain dict mapping requirement type to its number of occurrences.
    """
    counts: Dict[str, int] = {}
    for entry in items:
        kind = entry.get("requirementType") or "functional"
        if kind in counts:
            counts[kind] += 1
        else:
            counts[kind] = 1
    return counts