# Background job runner for SRS requirement extraction.
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import datetime
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, List
|
||
|
|
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.db.session import SessionLocal
|
||
|
|
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
|
||
|
|
from app.tools.srs_reqs_qwen import get_srs_tool
|
||
|
|
|
||
|
|
|
||
|
|
def run_srs_job(job_id: int) -> None:
    """Execute the SRS extraction job identified by *job_id*.

    Loads the job row, marks it ``processing``, runs the SRS tool on the
    job's input file, persists the extraction plus its requirement rows,
    and marks the job ``completed``.  Any exception rolls back the session
    and records the failure on the job (in a fresh session) instead of
    propagating — this function is a background-worker boundary.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            # Job row was deleted or never existed; nothing to do.
            return

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        db.commit()

        payload = get_srs_tool().run(job.input_file_path)

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=payload["document_name"],
            document_title=payload.get("document_title") or payload["document_name"],
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(payload.get("requirements", [])),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        db.add(extraction)
        db.flush()  # assign extraction.id before inserting child rows

        for item in payload.get("requirements", []):
            db.add(_requirement_from_item(extraction.id, item))

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        db.commit()
    except Exception as exc:  # worker boundary: record failure, never raise
        db.rollback()
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        db.close()


def _requirement_from_item(extraction_id: int, item: Dict[str, Any]) -> SRSRequirement:
    """Build one SRSRequirement row from a tool-output item, filling defaults.

    ``id`` is the only required key; every other field falls back to a
    placeholder so partially-filled tool output still persists.
    """
    return SRSRequirement(
        extraction_id=extraction_id,
        requirement_uid=item["id"],
        title=item.get("title") or item["id"],
        description=item.get("description") or "",
        priority=item.get("priority") or "中",
        acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
        source_field=item.get("source_field") or "文档解析",
        section_number=item.get("section_number"),
        section_title=item.get("section_title"),
        requirement_type=item.get("requirement_type"),
        sort_order=int(item.get("sort_order") or 0),
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Persist a failure status on the given job using a dedicated session.

    The error message is truncated to 2000 characters before storage; a
    job that no longer exists is silently ignored.
    """
    session = SessionLocal()
    try:
        failed_job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not failed_job:
            return
        failed_job.status = "failed"
        failed_job.completed_at = datetime.utcnow()
        failed_job.error_message = error_message[:2000]
        session.commit()
    finally:
        session.close()
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_generated_at(value: Any) -> datetime:
|
||
|
|
if isinstance(value, str):
|
||
|
|
try:
|
||
|
|
return datetime.fromisoformat(value)
|
||
|
|
except ValueError:
|
||
|
|
return datetime.utcnow()
|
||
|
|
return datetime.utcnow()
|
||
|
|
|
||
|
|
|
||
|
|
def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Return the destination path for an uploaded file of the given job.

    Creates ``uploads/srs_jobs/<job_id>/`` if it does not exist.  Only the
    basename of *file_name* is used, so a crafted upload name such as
    ``../../etc/passwd`` cannot escape the job directory (upload filenames
    are untrusted client input).
    """
    target_dir = Path("uploads") / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / Path(file_name).name
|
||
|
|
|
||
|
|
|
||
|
|
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Serialize a job's extraction and requirement rows into the API shape.

    Keys are camelCase to match the frontend contract; ``acceptanceCriteria``
    and ``statistics`` are normalized to empty containers when unset.
    """
    serialized_requirements = [
        {
            "id": req.requirement_uid,
            "title": req.title,
            "description": req.description,
            "priority": req.priority,
            "acceptanceCriteria": req.acceptance_criteria or [],
            "sourceField": req.source_field,
            "sectionNumber": req.section_number,
            "sectionTitle": req.section_title,
            "requirementType": req.requirement_type,
            "sortOrder": req.sort_order,
        }
        for req in extraction.requirements
    ]
    return {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": serialized_requirements,
    }
|
||
|
|
|
||
|
|
|
||
|
|
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Synchronize the extraction's requirement rows with *updates*.

    Rows are matched by requirement UID: matched rows are updated in place,
    unmatched update items are inserted, and existing rows absent from
    *updates* are deleted.  Extraction-level totals, statistics and the raw
    payload snapshot are refreshed to mirror the new list.  The caller is
    responsible for committing.

    Fixes two sortOrder defects: ``int(item.get("sortOrder", index))``
    raised TypeError when the payload carried an explicit ``None``, and
    ``item.get("sortOrder") or index`` silently replaced a legitimate 0
    with the list position.
    """
    existing = {
        req.requirement_uid: req
        for req in db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    }
    seen_ids = set()

    for index, item in enumerate(updates):
        uid = item["id"]
        seen_ids.add(uid)
        req = existing.get(uid)
        if req is None:
            # New UID: insert a fresh row, filling placeholder defaults.
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=item.get("title") or uid,
                    description=item.get("description") if item.get("description") is not None else "",
                    priority=item.get("priority") or "中",
                    acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
                    source_field=item.get("sourceField") or "文档解析",
                    section_number=item.get("sectionNumber"),
                    section_title=item.get("sectionTitle"),
                    requirement_type=item.get("requirementType"),
                    sort_order=_resolve_sort_order(item, index),
                )
            )
            continue

        # Existing UID: overwrite only the fields present in the item.
        req.title = item.get("title", req.title)
        req.description = item.get("description", req.description)
        req.priority = item.get("priority", req.priority)
        req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
        req.source_field = item.get("sourceField", req.source_field)
        req.section_number = item.get("sectionNumber", req.section_number)
        req.section_title = item.get("sectionTitle", req.section_title)
        req.requirement_type = item.get("requirementType", req.requirement_type)
        req.sort_order = _resolve_sort_order(item, index)

    # Delete rows whose UID no longer appears in the update payload.
    for uid, req in existing.items():
        if uid not in seen_ids:
            db.delete(req)

    extraction.total_requirements = len(updates)
    extraction.statistics = {
        "total": len(updates),
        "by_type": _count_requirement_types(updates),
    }
    extraction.raw_output = {
        "document_name": extraction.document_name,
        "generated_at": extraction.generated_at.isoformat(),
        "requirements": updates,
    }


def _resolve_sort_order(item: Dict[str, Any], fallback: int) -> int:
    """Return the item's sortOrder as int, or *fallback* when missing/None.

    A value of 0 is preserved; only a missing key or an explicit ``None``
    falls back to the list position.
    """
    raw = item.get("sortOrder")
    return int(raw) if raw is not None else fallback
|
||
|
|
|
||
|
|
|
||
|
|
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
|
||
|
|
stats: Dict[str, int] = {}
|
||
|
|
for item in items:
|
||
|
|
req_type = item.get("requirementType") or "functional"
|
||
|
|
stats[req_type] = stats.get(req_type, 0) + 1
|
||
|
|
return stats
|