"""SRS requirement-extraction job runner and result (de)serialization helpers."""

from __future__ import annotations

import logging
from collections import Counter
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

from sqlalchemy.orm import Session

from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.services.model_config import ModelConfigService
from app.tools.srs_reqs_qwen import get_srs_tool

logger = logging.getLogger(__name__)

# Internal requirement-type code -> Chinese label used in the raw output JSON.
TYPE_TO_CHINESE = {
    "functional": "功能需求",
    "interface": "接口需求",
    "performance": "性能需求",
    "security": "安全需求",
    "reliability": "可靠性需求",
    "other": "其他需求",
}

# Reverse lookup, derived so the two directions can never drift apart.
_CHINESE_TO_TYPE = {label: code for code, label in TYPE_TO_CHINESE.items()}

# Fallback values applied when a requirement field is missing or empty.
_DEFAULT_PRIORITY = "中"
_DEFAULT_ACCEPTANCE_CRITERION = "待补充验收标准"
_DEFAULT_SOURCE_FIELD = "文档解析"
_ORPHAN_SECTION_TITLE = "未归类章节"
_FALLBACK_UPLOAD_NAME = "upload"

# Sentence separators tried, in this priority order, when deriving a title.
# "\uff1b" is the full-width semicolon; it replaces a second, unreachable ";"
# entry (written as an escape so text normalisation cannot fold it to ASCII).
_TITLE_SEPARATORS = ("。", ";", "\n", "\uff1b", ".")

# Building blocks of the grouping keys produced by ``_section_key``.
_UID_KEY_PREFIX = "uid::"
_NUMBER_KEY_PREFIX = "number::"
_TITLE_KEY_MARKER = "::title::"

# Payload fields that only make sense for interface requirements.
_INTERFACE_ONLY_FIELDS = ("interfaceName", "interfaceType", "dataSource", "dataDestination")


def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
    """Derive a short title (at most 20 characters) from a description.

    The text is cut at the first separator from ``_TITLE_SEPARATORS`` that
    occurs in it (priority follows the tuple order, not the position in the
    text) and then truncated. When nothing usable remains, ``fallback`` is
    returned, or a generated "需求项 N" label if ``fallback`` is empty too.
    """
    default_title = fallback or f"需求项 {index + 1}"
    text = str(description or "").strip()
    if not text:
        return default_title
    for separator in _TITLE_SEPARATORS:
        if separator in text:
            text = text.split(separator, 1)[0].strip()
            break
    text = text[:20].strip()
    return text or default_title


def _normalize_requirement_type(value: Any) -> str:
    """Map a type code or Chinese label to a type code; unknown -> "functional"."""
    text = str(value or "").strip()
    if text in TYPE_TO_CHINESE:
        return text
    return _CHINESE_TO_TYPE.get(text, "functional")


def run_srs_job(job_id: int) -> None:
    """Run the SRS extraction tool for one job and persist its output.

    Opens its own database session. The job is marked "processing", the tool
    is run on ``job.input_file_path``, and one ``SRSExtraction`` plus one
    ``SRSRequirement`` per returned item are stored before the job is marked
    "completed". Any exception rolls the session back and marks the job
    "failed" instead; nothing is raised to the caller for those failures.
    A job id that does not exist is silently ignored.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        # Commit right away so the status change survives a later rollback.
        db.commit()

        model_profile = ModelConfigService.require_active_config(db, job.user_id)
        ModelConfigService.touch_last_used(db, model_profile)
        payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
        requirement_items = payload.get("requirements", [])

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=payload["document_name"],
            document_title=payload.get("document_title") or payload["document_name"],
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(requirement_items),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        db.add(extraction)
        # Flush so extraction.id is available for the child rows below.
        db.flush()

        for item in requirement_items:
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=item["id"],
                    title=_build_internal_title(item.get("description"), item["id"]),
                    description=item.get("description") or "",
                    priority=item.get("priority") or _DEFAULT_PRIORITY,
                    acceptance_criteria=item.get("acceptance_criteria")
                    or [_DEFAULT_ACCEPTANCE_CRITERION],
                    source_field=item.get("source_field") or _DEFAULT_SOURCE_FIELD,
                    section_uid=item.get("section_uid"),
                    section_number=item.get("section_number"),
                    section_title=item.get("section_title"),
                    requirement_type=item.get("requirement_type"),
                    interface_name=item.get("interface_name"),
                    interface_type=item.get("interface_type"),
                    data_source=item.get("data_source"),
                    data_destination=item.get("data_destination"),
                    sort_order=int(item.get("sort_order") or 0),
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        db.commit()
    except Exception as exc:
        # Deliberately broad: this is the job's top-level boundary. Log the
        # traceback first, because only str(exc) is stored on the job row.
        logger.exception("SRS job %s failed", job_id)
        db.rollback()
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        db.close()


def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Mark a job as "failed" in a fresh session, storing a truncated message."""
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return
        job.status = "failed"
        job.completed_at = datetime.utcnow()
        job.error_message = error_message[:2000]
        db.commit()
    finally:
        db.close()


def _parse_generated_at(value: Any) -> datetime:
    """Parse an ISO-8601 string; fall back to the current UTC time otherwise."""
    if isinstance(value, str):
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            # Unparseable timestamp: use the fallback below.
            pass
    return datetime.utcnow()


def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Create the upload directory of a job and return the target file path.

    Only the final path component of ``file_name`` is used, so directory
    parts, ".." segments, absolute paths and Windows-style backslash paths
    cannot place the file outside the job directory. A name with no usable
    final component is replaced by a fixed fallback name.
    """
    target_dir = build_srs_upload_path(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Treat backslashes as separators too, so "C:\\dir\\doc.docx" -> "doc.docx".
    safe_name = Path(str(file_name or "").replace("\\", "/")).name
    if safe_name in {"", ".", ".."}:
        safe_name = _FALLBACK_UPLOAD_NAME
    return target_dir / safe_name


def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Build the camelCase response dict for a job and its extraction."""
    requirements = [_requirement_model_to_payload(item) for item in extraction.requirements]
    raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        requirements,
        extraction.document_name,
    )
    return {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": requirements,
        "rawOutput": raw_output,
    }


def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Make the stored requirements of ``extraction`` match ``updates``.

    Requirements are matched by uid: unknown uids are inserted, known ones
    are updated in place, and stored requirements absent from ``updates``
    are deleted. The extraction's totals, statistics and raw output are
    refreshed as well. The session is not committed here.
    """
    normalized_updates = [
        _normalize_update_payload(item, index) for index, item in enumerate(updates)
    ]
    stored_rows = (
        db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    )
    existing = {req.requirement_uid: req for req in stored_rows}

    seen_ids = set()
    for index, item in enumerate(normalized_updates):
        uid = item["id"]
        seen_ids.add(uid)
        title_index = int(item.get("sortOrder") or index)
        req = existing.get(uid)

        if req is None:
            description = item.get("description")
            if description is None:
                description = ""
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=_build_internal_title(description, uid, title_index),
                    description=description,
                    priority=item.get("priority") or _DEFAULT_PRIORITY,
                    acceptance_criteria=item.get("acceptanceCriteria")
                    or [_DEFAULT_ACCEPTANCE_CRITERION],
                    source_field=item.get("sourceField") or _DEFAULT_SOURCE_FIELD,
                    section_uid=item.get("sectionUid"),
                    section_number=item.get("sectionNumber"),
                    section_title=item.get("sectionTitle"),
                    requirement_type=item.get("requirementType"),
                    interface_name=item.get("interfaceName"),
                    interface_type=item.get("interfaceType"),
                    data_source=item.get("dataSource"),
                    data_destination=item.get("dataDestination"),
                    sort_order=int(item.get("sortOrder") or 0),
                )
            )
            continue

        req.description = item.get("description", req.description)
        # The title is always regenerated from the (possibly new) description.
        req.title = _build_internal_title(req.description, req.requirement_uid, title_index)
        req.priority = item.get("priority", req.priority)
        req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
        req.source_field = item.get("sourceField", req.source_field)
        req.section_uid = item.get("sectionUid", req.section_uid)
        req.section_number = item.get("sectionNumber", req.section_number)
        req.section_title = item.get("sectionTitle", req.section_title)
        req.requirement_type = item.get("requirementType", req.requirement_type)
        req.interface_name = item.get("interfaceName", req.interface_name)
        req.interface_type = item.get("interfaceType", req.interface_type)
        req.data_source = item.get("dataSource", req.data_source)
        req.data_destination = item.get("dataDestination", req.data_destination)
        req.sort_order = int(item.get("sortOrder", req.sort_order))

    for uid, req in existing.items():
        if uid not in seen_ids:
            db.delete(req)

    extraction.total_requirements = len(normalized_updates)
    extraction.statistics = {
        "total": len(normalized_updates),
        "by_type": _count_requirement_types(normalized_updates),
    }
    extraction.raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        normalized_updates,
        extraction.document_name,
    )


def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
    """Return summaries of a user's jobs that have an extraction, newest first."""
    records: List[Tuple[ToolJob, SRSExtraction]] = (
        db.query(ToolJob, SRSExtraction)
        .join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
        .filter(ToolJob.user_id == user_id)
        .order_by(ToolJob.created_at.desc())
        .all()
    )
    return [
        {
            "jobId": job.id,
            "documentName": extraction.document_name,
            "generatedAt": extraction.generated_at.isoformat(),
            "totalRequirements": extraction.total_requirements,
            "status": job.status,
            "createdAt": job.created_at.isoformat(),
            "updatedAt": job.updated_at.isoformat(),
        }
        for job, extraction in records
    ]


def delete_srs_job(db: Session, job: ToolJob) -> None:
    """Delete the job row and commit."""
    db.delete(job)
    db.commit()


def build_srs_upload_path(job_id: int) -> Path:
    """Return the (relative) upload directory of a job without creating it."""
    return Path("uploads") / "srs_jobs" / str(job_id)


def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count payload items per ``requirementType`` (missing -> "functional")."""
    return dict(Counter(item.get("requirementType") or "functional" for item in items))


def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
    """Convert a stored requirement row into its camelCase payload dict."""
    return {
        "id": item.requirement_uid,
        "description": item.description,
        "priority": item.priority,
        "acceptanceCriteria": item.acceptance_criteria or [],
        "sourceField": item.source_field,
        "sectionUid": item.section_uid,
        "sectionNumber": item.section_number,
        "sectionTitle": item.section_title,
        "requirementType": item.requirement_type,
        "interfaceName": item.interface_name,
        "interfaceType": item.interface_type,
        "dataSource": item.data_source,
        "dataDestination": item.data_destination,
        "sortOrder": item.sort_order,
    }


def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
    """Fill defaults and normalise types for one incoming update item.

    ``index`` is the item's position in the update list; it supplies the
    generated id ("REQ-001", ...) and the sort order when those are missing
    or falsy. Interface-specific fields are blanked for every requirement
    whose normalised type is not "interface".
    """
    requirement_type = _normalize_requirement_type(item.get("requirementType"))
    normalized = {
        "id": str(item.get("id") or f"REQ-{index + 1:03d}"),
        "description": str(item.get("description") or "").strip(),
        "priority": item.get("priority") or _DEFAULT_PRIORITY,
        "acceptanceCriteria": item.get("acceptanceCriteria") or [_DEFAULT_ACCEPTANCE_CRITERION],
        "sourceField": item.get("sourceField") or _DEFAULT_SOURCE_FIELD,
        "sectionUid": item.get("sectionUid"),
        "sectionNumber": item.get("sectionNumber"),
        "sectionTitle": item.get("sectionTitle"),
        "requirementType": requirement_type,
        "interfaceName": item.get("interfaceName") or "",
        "interfaceType": item.get("interfaceType") or "",
        "dataSource": item.get("dataSource") or "",
        "dataDestination": item.get("dataDestination") or "",
        "sortOrder": int(item.get("sortOrder") or index),
    }
    if requirement_type != "interface":
        for field_name in _INTERFACE_ONLY_FIELDS:
            normalized[field_name] = ""
    return normalized


def _merge_updates_into_raw_output(
    raw_output: Dict[str, Any] | None,
    updates: List[Dict[str, Any]],
    document_name: str,
) -> Dict[str, Any]:
    """Return a copy of ``raw_output`` whose requirement lists mirror ``updates``.

    When ``raw_output`` has no dict under "需求内容", a fresh structure is
    built from ``updates`` alone. The input ``raw_output`` is never mutated.
    """
    content = raw_output.get("需求内容") if isinstance(raw_output, dict) else None
    if not isinstance(content, dict):
        return _build_raw_output_from_flat(updates, document_name)

    result = deepcopy(raw_output)
    content = result["需求内容"]
    updates_by_section = _group_updates_by_section(updates)
    # Matched groups are popped by the rewrite; whatever is left is orphaned.
    _rewrite_content_requirements(content, updates_by_section)
    _append_unmatched_sections(content, updates_by_section)
    _refresh_metadata(result, updates)
    return result


def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group payload items by section key, each group sorted by ``sortOrder``."""
    grouped: Dict[str, List[Dict[str, Any]]] = {}
    for item in updates:
        key = _section_key(
            item.get("sectionUid"),
            item.get("sectionNumber"),
            item.get("sectionTitle"),
        )
        grouped.setdefault(key, []).append(item)
    for values in grouped.values():
        values.sort(key=lambda value: int(value.get("sortOrder") or 0))
    return grouped


def _rewrite_content_requirements(
    content: Dict[str, Any],
    updates_by_section: Dict[str, List[Dict[str, Any]]],
) -> None:
    """Replace the requirement list of every section in ``content``, recursively.

    A section whose key has a group in ``updates_by_section`` receives that
    group (which is removed from the mapping); any other section that has a
    "需求列表" gets it emptied. Sub-sections under "子章节" are processed the
    same way. Both arguments are mutated in place.
    """
    for section in content.values():
        if not isinstance(section, dict):
            continue
        section_info = section.get("章节信息")
        if not isinstance(section_info, dict):
            # Missing or malformed section info: match on empty number/title.
            section_info = {}
        section_key = _section_key(
            section_info.get("章节UID"),
            section_info.get("章节编号"),
            section_info.get("章节标题"),
        )
        if section_key in updates_by_section:
            section["需求列表"] = [
                _to_raw_requirement_item(item) for item in updates_by_section.pop(section_key)
            ]
        elif "需求列表" in section:
            section["需求列表"] = []

        children = section.get("子章节")
        if isinstance(children, dict):
            _rewrite_content_requirements(children, updates_by_section)


def _append_unmatched_sections(
    content: Dict[str, Any],
    updates_by_section: Dict[str, List[Dict[str, Any]]],
) -> None:
    """Put all remaining groups into the "未归类章节" section of ``content``.

    The section is created when absent. ``updates_by_section`` is emptied.
    """
    if not updates_by_section:
        return

    orphan_key = _ORPHAN_SECTION_TITLE
    orphan_section = content.get(orphan_key)
    if not isinstance(orphan_section, dict):
        orphan_section = {
            "章节信息": {
                "章节编号": "",
                "章节标题": orphan_key,
                "章节级别": 1,
            },
            "需求列表": [],
        }
        content[orphan_key] = orphan_section

    current = orphan_section.get("需求列表")
    all_reqs: List[Dict[str, Any]] = current if isinstance(current, list) else []
    for values in updates_by_section.values():
        all_reqs.extend(_to_raw_requirement_item(item) for item in values)
    orphan_section["需求列表"] = all_reqs
    updates_by_section.clear()


def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
    """Update count, timestamp and per-type statistics under "文档元数据"."""
    metadata = raw_output.get("文档元数据")
    if not isinstance(metadata, dict):
        metadata = {}
        raw_output["文档元数据"] = metadata

    metadata["总需求数"] = len(updates)
    metadata["生成时间"] = datetime.now().isoformat()
    metadata["需求类型统计"] = dict(
        Counter(
            TYPE_TO_CHINESE.get(item.get("requirementType") or "functional", "其他需求")
            for item in updates
        )
    )


def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a camelCase payload item into a raw-output requirement entry."""
    req_type = item.get("requirementType") or "functional"
    raw_item = {
        "需求类型": TYPE_TO_CHINESE.get(req_type, "其他需求"),
        "需求编号": item.get("id") or "",
        "需求描述": item.get("description") or "",
        "优先级": item.get("priority") or _DEFAULT_PRIORITY,
    }
    if req_type == "interface":
        raw_item["接口名称"] = item.get("interfaceName") or ""
        raw_item["接口类型"] = item.get("interfaceType") or ""
        raw_item["数据来源"] = item.get("dataSource") or ""
        raw_item["数据目的地"] = item.get("dataDestination") or ""
    return raw_item


def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
    """Build a raw-output structure from scratch out of flat payload items.

    Items are grouped by section key. Groups that end up with the same
    display key (notably all uid-keyed groups, which share the
    "未归类章节" title) are merged into one section rather than overwriting
    each other, so no requirement is dropped.
    """
    content: Dict[str, Any] = {}
    for key, values in _group_updates_by_section(updates).items():
        number, title = _parse_section_key(key)
        section_title = title or _ORPHAN_SECTION_TITLE
        display_key = f"{number} {section_title}".strip()
        raw_items = [_to_raw_requirement_item(item) for item in values]

        section = content.get(display_key)
        if section is not None:
            section["需求列表"].extend(raw_items)
            continue

        content[display_key] = {
            "章节信息": {
                "章节编号": number,
                "章节标题": section_title,
                # Level = number of dot-separated parts of the section number.
                "章节级别": max(len(number.split(".")), 1) if number else 1,
            },
            "需求列表": raw_items,
        }

    raw_output = {
        "文档元数据": {
            "标题": document_name,
            "生成时间": datetime.now().isoformat(),
            "总需求数": len(updates),
            "需求类型统计": {},
        },
        "需求内容": content,
    }
    _refresh_metadata(raw_output, updates)
    return raw_output


def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
    """Return the grouping key of a section: by uid if present, else number+title."""
    uid = str(section_uid or "").strip()
    if uid:
        return f"{_UID_KEY_PREFIX}{uid}"
    number = str(section_number or "").strip()
    title = str(section_title or "").strip()
    return f"{_NUMBER_KEY_PREFIX}{number}{_TITLE_KEY_MARKER}{title}"


def _parse_section_key(value: str) -> Tuple[str, str]:
    """Recover ``(number, title)`` from a key made by ``_section_key``.

    uid-based keys carry no number or title and yield ``("", "未归类章节")``.
    Keys of any other unexpected shape yield ``("", "")``. The key is split
    at the first "::title::" marker, so a title containing "::" is returned
    intact.
    """
    if value.startswith(_UID_KEY_PREFIX):
        return "", _ORPHAN_SECTION_TITLE
    if not value.startswith(_NUMBER_KEY_PREFIX):
        return "", ""
    number, marker, title = value[len(_NUMBER_KEY_PREFIX):].partition(_TITLE_KEY_MARKER)
    if not marker:
        return "", ""
    return number, title