475 lines
18 KiB
Python
475 lines
18 KiB
Python
from __future__ import annotations
|
||
|
||
from copy import deepcopy
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Tuple
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db.session import SessionLocal
|
||
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
|
||
from app.services.model_config import ModelConfigService
|
||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||
|
||
# Internal requirement-type codes mapped to the Chinese labels that are
# written into the raw output ("需求类型" and the per-type statistics).
TYPE_TO_CHINESE = {
    "functional": "功能需求",
    "interface": "接口需求",
    "performance": "性能需求",
    "security": "安全需求",
    "reliability": "可靠性需求",
    "other": "其他需求",
}
|
||
|
||
|
||
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
    """Derive a short internal title from a requirement description.

    The title is the text before the first matching sentence separator,
    truncated to 20 characters.

    Args:
        description: Requirement description; any value is stringified and
            ``None``/empty values count as "no description".
        fallback: Title to use when no text can be derived (typically the
            requirement id).
        index: Zero-based position used to build the last-resort title
            ``"需求项 N"`` when ``fallback`` is empty as well.

    Returns:
        The derived title, or ``fallback``, or ``"需求项 {index + 1}"``.
    """
    default_title = fallback or f"需求项 {index + 1}"

    text = str(description or "").strip()
    if not text:
        return default_title

    # Separators are tried in priority order (not by position in the text)
    # and only the first one present is applied.  The tuple used to list the
    # ASCII ";" twice, so the full-width ";" common in Chinese text was
    # never honoured; it now takes the slot of the dead duplicate, which
    # keeps the behaviour unchanged for every input that has no ";".
    for separator in ("。", ";", "\n", ";", "."):
        if separator in text:
            text = text.split(separator, 1)[0].strip()
            break

    return text[:20].strip() or default_title
|
||
|
||
|
||
def _normalize_requirement_type(value: Any) -> str:
    """Map a requirement type given as code or Chinese label to its code.

    Known English codes are returned unchanged, known Chinese labels are
    translated, and anything else (including ``None``) becomes
    ``"functional"``.
    """
    known_codes = (
        "functional",
        "interface",
        "performance",
        "security",
        "reliability",
        "other",
    )
    label_to_code = {
        "接口需求": "interface",
        "性能需求": "performance",
        "安全需求": "security",
        "可靠性需求": "reliability",
        "其他需求": "other",
        "功能需求": "functional",
    }

    candidate = str(value or "").strip()
    if candidate in known_codes:
        return candidate
    return label_to_code.get(candidate, "functional")
|
||
|
||
|
||
def run_srs_job(job_id: int) -> None:
    """Run the SRS extraction tool for one job and persist its result.

    Loads the ``ToolJob`` with the given id, marks it ``"processing"``, runs
    the tool returned by ``get_srs_tool()`` on the job's input file with the
    user's active model configuration, stores one ``SRSExtraction`` plus one
    ``SRSRequirement`` per returned item, and finally marks the job
    ``"completed"``.

    Returns silently when no job has that id.  Any exception rolls the
    session back and is recorded on the job through ``_mark_job_failed``;
    it is not re-raised.
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        # Commit the "processing" state before the tool is invoked.
        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        session.commit()

        model_profile = ModelConfigService.require_active_config(session, job.user_id)
        ModelConfigService.touch_last_used(session, model_profile)
        tool = get_srs_tool()
        payload = tool.run(job.input_file_path, model_profile=model_profile)

        document_name = payload["document_name"]
        tool_items = payload.get("requirements", [])
        extraction = SRSExtraction(
            job_id=job.id,
            document_name=document_name,
            document_title=payload.get("document_title") or document_name,
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(tool_items),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        session.add(extraction)
        # Flush so that extraction.id is available to the child rows below.
        session.flush()

        for tool_item in tool_items:
            requirement_uid = tool_item["id"]
            session.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=requirement_uid,
                    title=_build_internal_title(tool_item.get("description"), requirement_uid),
                    description=tool_item.get("description") or "",
                    priority=tool_item.get("priority") or "中",
                    acceptance_criteria=tool_item.get("acceptance_criteria") or ["待补充验收标准"],
                    source_field=tool_item.get("source_field") or "文档解析",
                    section_uid=tool_item.get("section_uid"),
                    section_number=tool_item.get("section_number"),
                    section_title=tool_item.get("section_title"),
                    requirement_type=tool_item.get("requirement_type"),
                    interface_name=tool_item.get("interface_name"),
                    interface_type=tool_item.get("interface_type"),
                    data_source=tool_item.get("data_source"),
                    data_destination=tool_item.get("data_destination"),
                    sort_order=int(tool_item.get("sort_order") or 0),
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        session.commit()
    except Exception as exc:
        # Discard the partial work, then record the failure using a separate
        # session (see _mark_job_failed).
        session.rollback()
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        session.close()
|
||
|
||
|
||
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a failure on the job using a fresh session.

    Sets the status to ``"failed"``, stamps ``completed_at`` and stores the
    error message truncated to 2000 characters.  Does nothing when the job
    does not exist.
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if job:
            job.status = "failed"
            job.completed_at = datetime.utcnow()
            job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
|
||
|
||
|
||
def _parse_generated_at(value: Any) -> datetime:
    """Parse an ISO-8601 timestamp string, falling back to the current UTC time.

    Non-string values and strings that ``datetime.fromisoformat`` rejects
    both yield ``datetime.utcnow()``.
    """
    if not isinstance(value, str):
        return datetime.utcnow()
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return datetime.utcnow()
    return parsed
|
||
|
||
|
||
def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Create the upload directory for a job and return the target file path.

    ``file_name`` is treated as untrusted: only its final path component is
    used, so values such as ``"../../x"`` or an absolute path cannot point
    outside ``uploads/srs_jobs/<job_id>``.

    Args:
        job_id: Identifier of the job; becomes the directory name.
        file_name: Client-supplied file name.

    Returns:
        ``uploads/srs_jobs/<job_id>/<base name of file_name>``.

    Raises:
        ValueError: If ``file_name`` has no usable base name (empty, ``"."``
            or ``".."``).
    """
    # Treat backslashes as separators too, so Windows-style names such as
    # "..\\..\\x" are reduced to their last component on every platform.
    safe_name = Path(str(file_name).replace("\\", "/")).name
    if safe_name in {"", ".", ".."}:
        raise ValueError(f"invalid upload file name: {file_name!r}")

    target_dir = Path("uploads") / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / safe_name
|
||
|
||
|
||
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Assemble the API response for a finished extraction.

    The stored requirement rows are converted to camelCase payloads and
    merged into a copy of the stored raw output, so ``rawOutput`` reflects
    the current requirement rows.
    """
    requirement_payloads = [_requirement_model_to_payload(row) for row in extraction.requirements]
    merged_raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        requirement_payloads,
        extraction.document_name,
    )

    response: Dict[str, Any] = {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": requirement_payloads,
        "rawOutput": merged_raw_output,
    }
    return response
|
||
|
||
|
||
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Synchronise the stored requirements of an extraction with ``updates``.

    Each update is normalised first.  Rows whose uid appears in ``updates``
    are updated in place, unknown uids are inserted, and stored rows that
    are absent from ``updates`` are deleted.  The extraction's totals,
    statistics and raw output are then recomputed.  The session is not
    committed here.
    """
    normalized_updates = [_normalize_update_payload(raw, position) for position, raw in enumerate(updates)]

    stored_rows = (
        db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    )
    existing = {row.requirement_uid: row for row in stored_rows}
    seen_ids = set()

    # (model attribute, payload key) pairs copied as-is onto an existing row.
    passthrough_fields = (
        ("priority", "priority"),
        ("acceptance_criteria", "acceptanceCriteria"),
        ("source_field", "sourceField"),
        ("section_uid", "sectionUid"),
        ("section_number", "sectionNumber"),
        ("section_title", "sectionTitle"),
        ("requirement_type", "requirementType"),
        ("interface_name", "interfaceName"),
        ("interface_type", "interfaceType"),
        ("data_source", "dataSource"),
        ("data_destination", "dataDestination"),
    )

    for index, item in enumerate(normalized_updates):
        uid = item["id"]
        seen_ids.add(uid)
        req = existing.get(uid)

        if req is not None:
            # Existing row: overwrite its fields; the title is re-derived
            # from the (possibly new) description.
            req.description = item.get("description", req.description)
            req.title = _build_internal_title(
                req.description,
                req.requirement_uid,
                int(item.get("sortOrder") or index),
            )
            for attribute, payload_key in passthrough_fields:
                setattr(req, attribute, item.get(payload_key, getattr(req, attribute)))
            req.sort_order = int(item.get("sortOrder", req.sort_order))
            continue

        # Unknown uid: insert a new row.
        description = item.get("description")
        if description is None:
            description = ""
        db.add(
            SRSRequirement(
                extraction_id=extraction.id,
                requirement_uid=uid,
                title=_build_internal_title(description, uid, int(item.get("sortOrder") or index)),
                description=description,
                priority=item.get("priority") or "中",
                acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
                source_field=item.get("sourceField") or "文档解析",
                section_uid=item.get("sectionUid"),
                section_number=item.get("sectionNumber"),
                section_title=item.get("sectionTitle"),
                requirement_type=item.get("requirementType"),
                interface_name=item.get("interfaceName"),
                interface_type=item.get("interfaceType"),
                data_source=item.get("dataSource"),
                data_destination=item.get("dataDestination"),
                sort_order=int(item.get("sortOrder") or 0),
            )
        )

    # Remove stored rows that the caller no longer sent.
    for uid, row in existing.items():
        if uid in seen_ids:
            continue
        db.delete(row)

    total = len(normalized_updates)
    extraction.total_requirements = total
    extraction.statistics = {
        "total": total,
        "by_type": _count_requirement_types(normalized_updates),
    }
    extraction.raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        normalized_updates,
        extraction.document_name,
    )
|
||
|
||
|
||
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
    """Return the user's SRS jobs that have an extraction, newest first.

    Each entry summarises one job/extraction pair with camelCase keys and
    ISO-formatted timestamps.
    """
    query = (
        db.query(ToolJob, SRSExtraction)
        .join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
        .filter(ToolJob.user_id == user_id)
        .order_by(ToolJob.created_at.desc())
    )
    records: List[Tuple[ToolJob, SRSExtraction]] = query.all()

    return [
        {
            "jobId": job.id,
            "documentName": extraction.document_name,
            "generatedAt": extraction.generated_at.isoformat(),
            "totalRequirements": extraction.total_requirements,
            "status": job.status,
            "createdAt": job.created_at.isoformat(),
            "updatedAt": job.updated_at.isoformat(),
        }
        for job, extraction in records
    ]
|
||
|
||
|
||
def delete_srs_job(db: Session, job: ToolJob) -> None:
    """Delete the given job through ``db`` and commit immediately."""
    db.delete(job)
    db.commit()
|
||
|
||
|
||
def build_srs_upload_path(job_id: int) -> Path:
    """Return the relative upload directory ``uploads/srs_jobs/<job_id>``.

    The directory is not created here.
    """
    return Path("uploads", "srs_jobs", str(job_id))
|
||
|
||
|
||
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count items per ``requirementType``.

    A missing or empty type is counted as ``"functional"``.  Keys appear in
    order of first occurrence.
    """
    counts: Dict[str, int] = {}
    for entry in items:
        type_code = entry.get("requirementType") or "functional"
        counts.setdefault(type_code, 0)
        counts[type_code] += 1
    return counts
|
||
|
||
|
||
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
    """Convert a stored requirement row into its camelCase payload dict.

    ``acceptanceCriteria`` is an empty list when the row has none; every
    other value is copied as stored.
    """
    # (payload key, model attribute), in the order the keys are emitted.
    field_map = (
        ("id", "requirement_uid"),
        ("description", "description"),
        ("priority", "priority"),
        ("acceptanceCriteria", "acceptance_criteria"),
        ("sourceField", "source_field"),
        ("sectionUid", "section_uid"),
        ("sectionNumber", "section_number"),
        ("sectionTitle", "section_title"),
        ("requirementType", "requirement_type"),
        ("interfaceName", "interface_name"),
        ("interfaceType", "interface_type"),
        ("dataSource", "data_source"),
        ("dataDestination", "data_destination"),
        ("sortOrder", "sort_order"),
    )
    payload: Dict[str, Any] = {key: getattr(item, attribute) for key, attribute in field_map}
    payload["acceptanceCriteria"] = payload["acceptanceCriteria"] or []
    return payload
|
||
|
||
|
||
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
    """Return a fully populated, normalised copy of one client update.

    Missing values get defaults: the id becomes ``REQ-NNN`` derived from
    ``index``, and a missing or zero ``sortOrder`` falls back to ``index``.
    The requirement type is normalised to its internal code, and the four
    interface-only fields are blanked for every non-interface requirement.
    """
    requirement_type = _normalize_requirement_type(item.get("requirementType"))
    is_interface = requirement_type == "interface"

    def interface_value(key: str) -> Any:
        # Interface-only fields are kept solely for interface requirements.
        if not is_interface:
            return ""
        return item.get(key) or ""

    return {
        "id": str(item.get("id") or f"REQ-{index + 1:03d}"),
        "description": str(item.get("description") or "").strip(),
        "priority": item.get("priority") or "中",
        "acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
        "sourceField": item.get("sourceField") or "文档解析",
        "sectionUid": item.get("sectionUid"),
        "sectionNumber": item.get("sectionNumber"),
        "sectionTitle": item.get("sectionTitle"),
        "requirementType": requirement_type,
        "interfaceName": interface_value("interfaceName"),
        "interfaceType": interface_value("interfaceType"),
        "dataSource": interface_value("dataSource"),
        "dataDestination": interface_value("dataDestination"),
        "sortOrder": int(item.get("sortOrder") or index),
    }
|
||
|
||
|
||
def _merge_updates_into_raw_output(
    raw_output: Dict[str, Any] | None,
    updates: List[Dict[str, Any]],
    document_name: str,
) -> Dict[str, Any]:
    """Return a raw-output document that reflects ``updates``.

    When ``raw_output`` is a dict whose ``"需求内容"`` entry is itself a
    dict, a deep copy is rewritten section by section: matched sections get
    their new requirement lists, unmatched updates go to an orphan section,
    and the metadata is refreshed.  Otherwise a new document is built from
    the flat update list.  The input ``raw_output`` is never mutated.
    """
    if isinstance(raw_output, dict) and "需求内容" in raw_output:
        result = deepcopy(raw_output)
        content = result.get("需求内容")
        if isinstance(content, dict):
            pending = _group_updates_by_section(updates)
            # Consumes matching groups from `pending`; what remains afterwards
            # belongs to no existing section.
            _rewrite_content_requirements(content, pending)
            _append_unmatched_sections(content, pending)
            _refresh_metadata(result, updates)
            return result

    return _build_raw_output_from_flat(updates, document_name)
|
||
|
||
|
||
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group updates by their section key, each group sorted by ``sortOrder``.

    Groups appear in order of first occurrence.  The sort is stable, and a
    missing ``sortOrder`` counts as 0.
    """

    def sort_position(entry: Dict[str, Any]) -> int:
        return int(entry.get("sortOrder") or 0)

    buckets: Dict[str, List[Dict[str, Any]]] = {}
    for entry in updates:
        bucket_key = _section_key(
            entry.get("sectionUid"),
            entry.get("sectionNumber"),
            entry.get("sectionTitle"),
        )
        if bucket_key not in buckets:
            buckets[bucket_key] = []
        buckets[bucket_key].append(entry)

    for bucket in buckets.values():
        bucket.sort(key=sort_position)
    return buckets
|
||
|
||
|
||
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Replace each section's requirement list in place, recursing into children.

    A section whose key matches a group in ``updates_by_section`` receives
    that group (which is removed from the mapping).  A section without a
    match has an existing ``"需求列表"`` emptied.  Non-dict entries are
    skipped.
    """
    for node in content.values():
        if not isinstance(node, dict):
            continue

        info = node.get("章节信息") or {}
        node_key = _section_key(
            info.get("章节UID"),
            info.get("章节编号"),
            info.get("章节标题"),
        )

        if node_key in updates_by_section:
            matched = updates_by_section.pop(node_key)
            node["需求列表"] = [_to_raw_requirement_item(entry) for entry in matched]
        elif "需求列表" in node:
            node["需求列表"] = []

        sub_sections = node.get("子章节")
        if isinstance(sub_sections, dict):
            _rewrite_content_requirements(sub_sections, updates_by_section)
|
||
|
||
|
||
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Append all remaining update groups to the ``"未归类章节"`` section.

    The orphan section is created when ``content`` has no dict under that
    key.  ``updates_by_section`` is emptied afterwards.  Nothing happens
    when it is already empty.
    """
    if not updates_by_section:
        return

    orphan_key = "未归类章节"
    orphan_section = content.get(orphan_key)
    if not isinstance(orphan_section, dict):
        orphan_section = {
            "章节信息": {
                "章节编号": "",
                "章节标题": orphan_key,
                "章节级别": 1,
            },
            "需求列表": [],
        }
        content[orphan_key] = orphan_section

    collected: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
    collected.extend(
        _to_raw_requirement_item(entry)
        for group in updates_by_section.values()
        for entry in group
    )
    orphan_section["需求列表"] = collected
    updates_by_section.clear()
|
||
|
||
|
||
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
    """Update the ``"文档元数据"`` entry of ``raw_output`` in place.

    Writes the requirement count, the current local timestamp, and the
    per-type counts keyed by Chinese label.  The metadata dict is created
    when it is missing or not a dict.
    """
    metadata = raw_output.get("文档元数据")
    if not isinstance(metadata, dict):
        metadata = raw_output["文档元数据"] = {}

    metadata["总需求数"] = len(updates)
    metadata["生成时间"] = datetime.now().isoformat()

    type_stats: Dict[str, int] = {}
    for entry in updates:
        type_code = entry.get("requirementType") or "functional"
        # Codes without a known label are counted as "其他需求".
        label = TYPE_TO_CHINESE.get(type_code, "其他需求")
        type_stats[label] = 1 + type_stats.get(label, 0)
    metadata["需求类型统计"] = type_stats
|
||
|
||
|
||
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a camelCase requirement payload into a raw-output item.

    The result uses Chinese keys.  The four interface fields are added only
    when the requirement type is ``"interface"``.
    """
    type_code = item.get("requirementType") or "functional"
    raw_item: Dict[str, Any] = {
        "需求类型": TYPE_TO_CHINESE.get(type_code, "其他需求"),
        "需求编号": item.get("id") or "",
        "需求描述": item.get("description") or "",
        "优先级": item.get("priority") or "中",
    }
    if type_code != "interface":
        return raw_item

    interface_fields = (
        ("接口名称", "interfaceName"),
        ("接口类型", "interfaceType"),
        ("数据来源", "dataSource"),
        ("数据目的地", "dataDestination"),
    )
    for label, payload_key in interface_fields:
        raw_item[label] = item.get(payload_key) or ""
    return raw_item
|
||
|
||
|
||
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
    """Build a raw-output document from a flat list of requirement payloads.

    Requirements are grouped by section key and emitted as one section per
    display key (``"<number> <title>"``).  Groups that resolve to the same
    display key are merged into a single section.

    Args:
        updates: Requirement payloads with camelCase keys.
        document_name: Stored as the document title in the metadata.

    Returns:
        A dict with ``"文档元数据"`` and ``"需求内容"`` entries.
    """
    grouped = _group_updates_by_section(updates)
    content: Dict[str, Any] = {}
    for key, values in grouped.items():
        number, title = _parse_section_key(key)
        section_title = title or "未归类章节"
        display_key = f"{number} {section_title}".strip()
        raw_items = [_to_raw_requirement_item(item) for item in values]

        section = content.get(display_key)
        if section is not None:
            # Different groups can share a display key: every uid-keyed group
            # parses to "未归类章节".  Assigning the key again would silently
            # drop the earlier group's requirements, so merge instead.
            section["需求列表"].extend(raw_items)
            continue

        content[display_key] = {
            "章节信息": {
                "章节编号": number,
                "章节标题": section_title,
                # Level is the number of dot-separated parts of the number.
                "章节级别": max(len(number.split(".")), 1) if number else 1,
            },
            "需求列表": raw_items,
        }

    raw_output = {
        "文档元数据": {
            "标题": document_name,
            "生成时间": datetime.now().isoformat(),
            "总需求数": len(updates),
            "需求类型统计": {},
        },
        "需求内容": content,
    }
    # Fills in the per-type statistics (and re-stamps count and timestamp).
    _refresh_metadata(raw_output, updates)
    return raw_output
|
||
|
||
|
||
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
    """Build the grouping key for a section.

    A non-blank uid wins and yields ``"uid::<uid>"``.  Otherwise the key is
    ``"number::<number>::title::<title>"``.  All parts are stringified and
    stripped, and falsy values become empty strings.
    """

    def clean(part: Any) -> str:
        return str(part or "").strip()

    uid = clean(section_uid)
    if uid:
        return "uid::" + uid
    return "number::{}::title::{}".format(clean(section_number), clean(section_title))
|
||
|
||
|
||
def _parse_section_key(value: str) -> Tuple[str, str]:
    """Split a section key back into ``(number, title)``.

    Keys have the two shapes produced by ``_section_key``:
    ``"uid::<uid>"`` and ``"number::<number>::title::<title>"``.

    Args:
        value: The section key to parse.

    Returns:
        ``("", "未归类章节")`` for uid keys, which carry no number or title;
        ``(number, title)`` for number/title keys; ``("", "")`` when the key
        cannot be parsed.
    """
    if value.startswith("uid::"):
        return "", "未归类章节"

    prefix = "number::"
    marker = "::title::"
    if value.startswith(prefix):
        remainder = value[len(prefix):]
        if marker in remainder:
            # Split on the fixed marker rather than on every "::", so a title
            # (or number) that itself contains "::" is not truncated.
            number, _, title = remainder.partition(marker)
            return number, title

    # Fallback for keys that do not follow the expected layout.
    parts = value.split("::")
    if len(parts) >= 4:
        return parts[1], parts[3]
    return "", ""
|