from __future__ import annotations

import logging
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session

from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool

logger = logging.getLogger(__name__)

# Fallback values written when a requirement field is missing. They are stored
# data, so the Chinese literals are kept verbatim: "中" = "medium" priority,
# "待补充验收标准" = "acceptance criteria to be added", "文档解析" = "document parsing".
_DEFAULT_PRIORITY = "中"
_DEFAULT_ACCEPTANCE_CRITERION = "待补充验收标准"
_DEFAULT_SOURCE_FIELD = "文档解析"
# Upper bound applied to ToolJob.error_message.
_MAX_ERROR_MESSAGE_LENGTH = 2000


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value as the deprecated ``datetime.utcnow()``; kept naive because this
    module has always written naive UTC timestamps.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _coerce_sort_order(value: Any, default: int) -> int:
    """Convert a supplied sort order to ``int``, using *default* when it is absent.

    ``None`` and ``""`` count as "not provided". An explicit ``0`` is kept
    (a ``value or default`` expression would silently discard it).
    """
    if value is None or value == "":
        return default
    return int(value)


def run_srs_job(job_id: int) -> None:
    """Run SRS requirement extraction for ``ToolJob`` *job_id* and persist the result.

    Marks the job ``"processing"``, runs the SRS tool on ``job.input_file_path``,
    stores one ``SRSExtraction`` plus one ``SRSRequirement`` per extracted item,
    then marks the job ``"completed"`` with a small summary. Returns without
    doing anything if the job does not exist. Any exception is logged, the
    session rolled back, and the job marked ``"failed"`` via a separate session;
    the exception is not re-raised.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        job.status = "processing"
        job.started_at = _utcnow()
        job.error_message = None
        # Commit immediately so the "processing" state is visible while the tool runs.
        db.commit()

        payload = get_srs_tool().run(job.input_file_path)
        # `or` (not a .get default) so an explicit null in the tool output is tolerated.
        requirements = payload.get("requirements") or []

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=payload["document_name"],
            document_title=payload.get("document_title") or payload["document_name"],
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(requirements),
            statistics=payload.get("statistics") or {},
            raw_output=payload.get("raw_output") or {},
        )
        db.add(extraction)
        # Flush so extraction.id is assigned before the child rows reference it.
        db.flush()

        for index, item in enumerate(requirements):
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=item["id"],
                    title=item.get("title") or item["id"],
                    description=item.get("description") or "",
                    priority=item.get("priority") or _DEFAULT_PRIORITY,
                    acceptance_criteria=(
                        item.get("acceptance_criteria") or [_DEFAULT_ACCEPTANCE_CRITERION]
                    ),
                    source_field=item.get("source_field") or _DEFAULT_SOURCE_FIELD,
                    section_number=item.get("section_number"),
                    section_title=item.get("section_title"),
                    requirement_type=item.get("requirement_type"),
                    # Fall back to the item's position, matching replace_requirements.
                    sort_order=_coerce_sort_order(item.get("sort_order"), index),
                )
            )

        job.status = "completed"
        job.completed_at = _utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        db.commit()
    except Exception as exc:
        # Top-level boundary of the job: record the failure on the job row
        # instead of propagating, but keep the traceback in the logs.
        logger.exception("SRS job %s failed", job_id)
        db.rollback()
        # Some exceptions stringify to ""; fall back to the class name.
        _mark_job_failed(job_id=job_id, error_message=str(exc) or type(exc).__name__)
    finally:
        db.close()


def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Set job *job_id* to ``"failed"`` in a fresh session and commit.

    *error_message* is truncated to ``_MAX_ERROR_MESSAGE_LENGTH`` characters.
    Does nothing if the job does not exist.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return
        job.status = "failed"
        job.completed_at = _utcnow()
        job.error_message = error_message[:_MAX_ERROR_MESSAGE_LENGTH]
        db.commit()
    finally:
        db.close()


def _parse_generated_at(value: Any) -> datetime:
    """Parse the tool's ``generated_at`` value into a naive UTC ``datetime``.

    Accepts ``datetime`` objects and ISO-8601 strings, including a trailing
    ``Z``. Offset-aware values are converted to UTC and made naive so they match
    the naive-UTC fallback. Naive inputs are returned unchanged (presumably
    already UTC — confirm with the tool). Anything else, or an unparseable
    string, yields the current UTC time.
    """
    parsed: Optional[datetime] = None
    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, str):
        text = value.strip()
        # datetime.fromisoformat() only understands a "Z" suffix from Python 3.11.
        if text.endswith(("Z", "z")):
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            parsed = None

    if parsed is None:
        return _utcnow()
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Create ``uploads/srs_jobs/<job_id>/`` and return the path for *file_name* in it.

    Only the final path component of *file_name* is used, so a client-supplied
    name such as ``../../x`` or ``/etc/x`` cannot escape the job directory.

    Raises:
        ValueError: if *file_name* has no usable final component
            (e.g. ``""``, ``"."``, ``".."``).
    """
    # Treat backslashes as separators too, for names sent by Windows clients.
    safe_name = Path(str(file_name).replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"invalid upload file name: {file_name!r}")

    target_dir = Path("uploads") / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / safe_name


def _requirement_sort_key(req: SRSRequirement) -> Tuple[bool, int]:
    """Order requirements by ``sort_order``, placing rows without one last."""
    return (req.sort_order is None, req.sort_order or 0)


def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Serialize *extraction* and its requirements into the camelCase API payload.

    Requirements are returned ordered by ``sortOrder``. The sort is stable, so
    ties keep the relationship's load order.
    """
    requirements: List[Dict[str, Any]] = [
        {
            "id": item.requirement_uid,
            "title": item.title,
            "description": item.description,
            "priority": item.priority,
            "acceptanceCriteria": item.acceptance_criteria or [],
            "sourceField": item.source_field,
            "sectionNumber": item.section_number,
            "sectionTitle": item.section_title,
            "requirementType": item.requirement_type,
            "sortOrder": item.sort_order,
        }
        for item in sorted(extraction.requirements, key=_requirement_sort_key)
    ]
    generated_at = extraction.generated_at
    return {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": generated_at.isoformat() if generated_at is not None else None,
        "statistics": extraction.statistics or {},
        "requirements": requirements,
    }


def _pick(item: Dict[str, Any], key: str, current: Any, default: Any) -> Any:
    """Return the new value of a field during a partial update.

    If *key* is absent, *current* is kept. An explicit ``None`` becomes
    *default* (the same fallback used when inserting a new row). Any other value
    is used as-is.
    """
    if key not in item:
        return current
    value = item[key]
    return default if value is None else value


def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Make *extraction*'s requirement rows match *updates* (camelCase dicts keyed by ``"id"``).

    - Existing rows whose ``requirement_uid`` appears in *updates* are updated;
      fields absent from the dict keep their stored value (``sortOrder``
      defaults to the item's position).
    - Unknown ids are inserted with the same defaults ``run_srs_job`` uses.
    - Rows not mentioned are deleted.

    If an id occurs more than once, the last occurrence's fields win and the row
    takes the position of the first occurrence. Totals, per-type statistics and
    ``raw_output`` are then recomputed. The session is neither flushed nor
    committed here.
    """
    # Collapse duplicate ids first; otherwise two rows with the same uid get
    # inserted. A dict keeps first-seen order while later items overwrite values.
    unique_updates: Dict[str, Dict[str, Any]] = {}
    for item in updates:
        unique_updates[item["id"]] = item
    items = list(unique_updates.values())

    existing = {
        req.requirement_uid: req
        for req in db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    }

    final_rows: List[SRSRequirement] = []
    for index, item in enumerate(items):
        uid = item["id"]
        # pop(): whatever is left in `existing` after the loop must be deleted.
        req = existing.pop(uid, None)
        if req is None:
            description = item.get("description")
            req = SRSRequirement(
                extraction_id=extraction.id,
                requirement_uid=uid,
                title=item.get("title") or uid,
                description=description if description is not None else "",
                priority=item.get("priority") or _DEFAULT_PRIORITY,
                acceptance_criteria=(
                    item.get("acceptanceCriteria") or [_DEFAULT_ACCEPTANCE_CRITERION]
                ),
                source_field=item.get("sourceField") or _DEFAULT_SOURCE_FIELD,
                section_number=item.get("sectionNumber"),
                section_title=item.get("sectionTitle"),
                requirement_type=item.get("requirementType"),
                sort_order=_coerce_sort_order(item.get("sortOrder"), index),
            )
            db.add(req)
        else:
            req.title = _pick(item, "title", req.title, uid)
            req.description = _pick(item, "description", req.description, "")
            req.priority = _pick(item, "priority", req.priority, _DEFAULT_PRIORITY)
            req.acceptance_criteria = _pick(
                item,
                "acceptanceCriteria",
                req.acceptance_criteria,
                [_DEFAULT_ACCEPTANCE_CRITERION],
            )
            req.source_field = _pick(item, "sourceField", req.source_field, _DEFAULT_SOURCE_FIELD)
            # These columns accept None for new rows too, so an explicit null is kept.
            req.section_number = item.get("sectionNumber", req.section_number)
            req.section_title = item.get("sectionTitle", req.section_title)
            req.requirement_type = item.get("requirementType", req.requirement_type)
            req.sort_order = _coerce_sort_order(item.get("sortOrder"), index)
        final_rows.append(req)

    for stale in existing.values():
        db.delete(stale)

    extraction.total_requirements = len(final_rows)
    extraction.statistics = {
        "total": len(final_rows),
        # Count the rows' actual types so a partial update that omits
        # requirementType still counts the stored type.
        "by_type": _count_requirement_types(
            [{"requirementType": req.requirement_type} for req in final_rows]
        ),
    }
    generated_at = extraction.generated_at
    extraction.raw_output = {
        "document_name": extraction.document_name,
        "generated_at": generated_at.isoformat() if generated_at is not None else None,
        "requirements": items,
    }


def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count items per ``"requirementType"``, treating a missing or empty type as ``"functional"``."""
    return dict(Counter(item.get("requirementType") or "functional" for item in items))