Files
rag_agent/rag-web-ui/backend/app/services/srs_job_service.py

475 lines
18 KiB
Python
Raw Normal View History

2026-04-13 11:34:23 +08:00
from __future__ import annotations
from copy import deepcopy
2026-04-13 11:34:23 +08:00
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple
2026-04-13 11:34:23 +08:00
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.services.model_config import ModelConfigService
2026-04-13 11:34:23 +08:00
from app.tools.srs_reqs_qwen import get_srs_tool
# Canonical requirement type code -> Chinese label. Within this module the
# labels are written into raw-output items ("需求类型") and used as the keys of
# the per-type statistics stored in the raw-output metadata.
TYPE_TO_CHINESE: Dict[str, str] = {
    "functional": "功能需求",
    "interface": "接口需求",
    "performance": "性能需求",
    "security": "安全需求",
    "reliability": "可靠性需求",
    "other": "其他需求",
}
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
text = str(description or "").strip()
if not text:
return fallback or f"需求项 {index + 1}"
for separator in ("", "", "\n", ";", "."):
if separator in text:
text = text.split(separator, 1)[0].strip()
break
text = text[:20].strip()
return text or fallback or f"需求项 {index + 1}"
def _normalize_requirement_type(value: Any) -> str:
text = str(value or "").strip()
if text in {"functional", "interface", "performance", "security", "reliability", "other"}:
return text
chinese_map = {
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
"功能需求": "functional",
}
return chinese_map.get(text, "functional")
2026-04-13 11:34:23 +08:00
def run_srs_job(job_id: int) -> None:
    """Run the SRS extraction for one tool job and persist its outcome.

    Opens its own session, marks the job ``processing``, runs the SRS tool on
    the job's input file with the user's active model configuration, and
    stores one ``SRSExtraction`` plus one ``SRSRequirement`` per extracted
    item before marking the job ``completed``. If the job id is unknown the
    function returns without doing anything.

    Any exception rolls the session back and is recorded on the job through
    ``_mark_job_failed``; it is not re-raised. The session is always closed.
    """
    # Optional item fields copied onto the requirement row without defaults.
    passthrough_fields = (
        "section_uid",
        "section_number",
        "section_title",
        "requirement_type",
        "interface_name",
        "interface_type",
        "data_source",
        "data_destination",
    )

    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        # Persist the "processing" state before the extraction starts.
        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        db.commit()

        model_profile = ModelConfigService.require_active_config(db, job.user_id)
        ModelConfigService.touch_last_used(db, model_profile)
        payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)

        document_name = payload["document_name"]
        document_title = payload.get("document_title") or document_name
        generated_at = _parse_generated_at(payload.get("generated_at"))
        requirement_items = payload.get("requirements", [])

        extraction = SRSExtraction(
            job_id=job.id,
            document_name=document_name,
            document_title=document_title,
            generated_at=generated_at,
            total_requirements=len(requirement_items),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        db.add(extraction)
        # Flush so extraction.id is available for the requirement rows.
        db.flush()

        for item in requirement_items:
            requirement_uid = item["id"]
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=requirement_uid,
                    title=_build_internal_title(item.get("description"), requirement_uid),
                    description=item.get("description") or "",
                    priority=item.get("priority") or "",
                    acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
                    source_field=item.get("source_field") or "文档解析",
                    sort_order=int(item.get("sort_order") or 0),
                    **{field: item.get(field) for field in passthrough_fields},
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        db.commit()
    except Exception as exc:
        db.rollback()
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        db.close()
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a failure on the given job using a separate session.

    Sets the status to ``failed``, stamps the completion time and stores the
    error message truncated to 2000 characters. Does nothing when the job no
    longer exists. The session is closed in every case.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if job:
            job.status = "failed"
            job.completed_at = datetime.utcnow()
            job.error_message = error_message[:2000]
            db.commit()
    finally:
        db.close()
def _parse_generated_at(value: Any) -> datetime:
if isinstance(value, str):
try:
return datetime.fromisoformat(value)
except ValueError:
return datetime.utcnow()
return datetime.utcnow()
def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Create the upload directory for a job and return the file's target path.

    Args:
        job_id: Identifier of the job; used as the directory name under
            ``uploads/srs_jobs``.
        file_name: Name of the uploaded file. Only its final path component
            is used.

    Returns:
        ``uploads/srs_jobs/<job_id>/<final component of file_name>``. The
        directory is created if needed; the file itself is not.
    """
    target_dir = Path("uploads") / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Keep only the final path component so that names such as "../../x" or
    # an absolute path cannot place the file outside the job directory
    # (joining an absolute path would otherwise discard target_dir entirely).
    return target_dir / Path(file_name).name
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Assemble the API response for a finished extraction.

    The stored requirement rows are converted to camelCase payloads, and the
    returned ``rawOutput`` is the stored raw output with those payloads merged
    back in, so it reflects the rows as currently stored.
    """
    requirement_payloads = list(map(_requirement_model_to_payload, extraction.requirements))
    merged_raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        requirement_payloads,
        extraction.document_name,
    )
    response: Dict[str, Any] = {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": requirement_payloads,
        "rawOutput": merged_raw_output,
    }
    return response
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Synchronise an extraction's stored requirements with an edited list.

    Every update is normalised first. Stored rows whose ``requirement_uid``
    matches an update are modified in place, unknown ids are inserted as new
    rows, and stored rows that no update mentions are deleted. Afterwards the
    extraction's requirement count, statistics and raw output are rebuilt
    from the normalised updates. This function does not commit.
    """
    normalized_updates = [
        _normalize_update_payload(raw_item, position)
        for position, raw_item in enumerate(updates)
    ]

    stored_rows = (
        db.query(SRSRequirement)
        .filter(SRSRequirement.extraction_id == extraction.id)
        .all()
    )
    existing = {row.requirement_uid: row for row in stored_rows}

    # (model attribute, payload key) pairs copied as-is onto existing rows.
    passthrough_fields = (
        ("priority", "priority"),
        ("acceptance_criteria", "acceptanceCriteria"),
        ("source_field", "sourceField"),
        ("section_uid", "sectionUid"),
        ("section_number", "sectionNumber"),
        ("section_title", "sectionTitle"),
        ("requirement_type", "requirementType"),
        ("interface_name", "interfaceName"),
        ("interface_type", "interfaceType"),
        ("data_source", "dataSource"),
        ("data_destination", "dataDestination"),
    )

    seen_ids = set()
    for position, payload in enumerate(normalized_updates):
        uid = payload["id"]
        seen_ids.add(uid)
        title_index = int(payload.get("sortOrder") or position)
        row = existing.get(uid)

        if row is not None:
            # Update the stored row; the title is re-derived from the
            # description that was just assigned.
            row.description = payload.get("description", row.description)
            row.title = _build_internal_title(row.description, row.requirement_uid, title_index)
            for attribute, key in passthrough_fields:
                setattr(row, attribute, payload.get(key, getattr(row, attribute)))
            row.sort_order = int(payload.get("sortOrder", row.sort_order))
            continue

        description = payload.get("description")
        if description is None:
            description = ""
        db.add(
            SRSRequirement(
                extraction_id=extraction.id,
                requirement_uid=uid,
                title=_build_internal_title(description, uid, title_index),
                description=description,
                priority=payload.get("priority") or "",
                acceptance_criteria=payload.get("acceptanceCriteria") or ["待补充验收标准"],
                source_field=payload.get("sourceField") or "文档解析",
                section_uid=payload.get("sectionUid"),
                section_number=payload.get("sectionNumber"),
                section_title=payload.get("sectionTitle"),
                requirement_type=payload.get("requirementType"),
                interface_name=payload.get("interfaceName"),
                interface_type=payload.get("interfaceType"),
                data_source=payload.get("dataSource"),
                data_destination=payload.get("dataDestination"),
                sort_order=int(payload.get("sortOrder") or 0),
            )
        )

    # Remove stored rows that are absent from the submitted list.
    for uid, row in existing.items():
        if uid not in seen_ids:
            db.delete(row)

    total = len(normalized_updates)
    extraction.total_requirements = total
    extraction.statistics = {
        "total": total,
        "by_type": _count_requirement_types(normalized_updates),
    }
    extraction.raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        normalized_updates,
        extraction.document_name,
    )
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
    """List a user's jobs that have an extraction, newest job first.

    Returns one summary dict per (job, extraction) pair, ordered by the job's
    creation time in descending order.
    """
    history_query = (
        db.query(ToolJob, SRSExtraction)
        .join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
        .filter(ToolJob.user_id == user_id)
        .order_by(ToolJob.created_at.desc())
    )
    records: List[Tuple[ToolJob, SRSExtraction]] = history_query.all()
    return [
        {
            "jobId": job.id,
            "documentName": extraction.document_name,
            "generatedAt": extraction.generated_at.isoformat(),
            "totalRequirements": extraction.total_requirements,
            "status": job.status,
            "createdAt": job.created_at.isoformat(),
            "updatedAt": job.updated_at.isoformat(),
        }
        for job, extraction in records
    ]
def delete_srs_job(db: Session, job: ToolJob) -> None:
    """Delete the given job through ``db`` and commit immediately."""
    db.delete(job)
    db.commit()
def build_srs_upload_path(job_id: int) -> Path:
    """Return the relative upload directory for a job, without creating it."""
    return Path("uploads", "srs_jobs", str(job_id))
2026-04-13 11:34:23 +08:00
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
stats: Dict[str, int] = {}
for item in items:
req_type = item.get("requirementType") or "functional"
stats[req_type] = stats.get(req_type, 0) + 1
return stats
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
    """Convert a stored requirement row into its camelCase payload dict.

    ``acceptanceCriteria`` is always returned as a list: a missing or empty
    stored value becomes ``[]``.
    """
    # (payload key, model attribute) in output order.
    field_map = (
        ("id", "requirement_uid"),
        ("description", "description"),
        ("priority", "priority"),
        ("acceptanceCriteria", "acceptance_criteria"),
        ("sourceField", "source_field"),
        ("sectionUid", "section_uid"),
        ("sectionNumber", "section_number"),
        ("sectionTitle", "section_title"),
        ("requirementType", "requirement_type"),
        ("interfaceName", "interface_name"),
        ("interfaceType", "interface_type"),
        ("dataSource", "data_source"),
        ("dataDestination", "data_destination"),
        ("sortOrder", "sort_order"),
    )
    payload = {key: getattr(item, attribute) for key, attribute in field_map}
    payload["acceptanceCriteria"] = payload["acceptanceCriteria"] or []
    return payload
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
    """Return a fully-populated copy of one client-submitted requirement.

    Missing values receive defaults: a generated ``REQ-NNN`` id based on
    ``index``, placeholder acceptance criteria, a default source field, and
    ``index`` as the sort order. The requirement type is normalised to its
    canonical code, and the four interface-specific fields are blanked for
    every type other than ``"interface"``.
    """
    requirement_type = _normalize_requirement_type(item.get("requirementType"))

    def text_or_blank(key: str) -> Any:
        return item.get(key) or ""

    normalized = {
        "id": str(item.get("id") or f"REQ-{index + 1:03d}"),
        "description": str(item.get("description") or "").strip(),
        "priority": text_or_blank("priority"),
        "acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
        "sourceField": item.get("sourceField") or "文档解析",
        "sectionUid": item.get("sectionUid"),
        "sectionNumber": item.get("sectionNumber"),
        "sectionTitle": item.get("sectionTitle"),
        "requirementType": requirement_type,
        "interfaceName": text_or_blank("interfaceName"),
        "interfaceType": text_or_blank("interfaceType"),
        "dataSource": text_or_blank("dataSource"),
        "dataDestination": text_or_blank("dataDestination"),
        "sortOrder": int(item.get("sortOrder") or index),
    }

    if requirement_type != "interface":
        # Interface-only fields carry no meaning for other requirement types.
        interface_only_keys = ("interfaceName", "interfaceType", "dataSource", "dataDestination")
        normalized.update(dict.fromkeys(interface_only_keys, ""))
    return normalized
def _merge_updates_into_raw_output(
    raw_output: Dict[str, Any] | None,
    updates: List[Dict[str, Any]],
    document_name: str,
) -> Dict[str, Any]:
    """Produce a raw-output document that reflects the given requirements.

    When ``raw_output`` is a dict whose ``需求内容`` entry is itself a dict, a
    deep copy is updated: each section's requirement list is rewritten from
    the updates, updates matching no section are appended to an orphan
    section, and the metadata is refreshed. In every other case a new
    document is built from the flat update list. The input is never mutated.
    """
    if isinstance(raw_output, dict) and "需求内容" in raw_output:
        merged = deepcopy(raw_output)
        content = merged.get("需求内容")
        if isinstance(content, dict):
            pending_by_section = _group_updates_by_section(updates)
            _rewrite_content_requirements(content, pending_by_section)
            # Whatever is still pending matched no existing section.
            _append_unmatched_sections(content, pending_by_section)
            _refresh_metadata(merged, updates)
            return merged
    return _build_raw_output_from_flat(updates, document_name)
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Bucket requirement payloads by section key.

    Buckets appear in the order their section is first seen, and each bucket
    is sorted by ``sortOrder`` (missing or falsy values sort as 0).
    """
    buckets: Dict[str, List[Dict[str, Any]]] = {}
    for update in updates:
        key = _section_key(
            update.get("sectionUid"),
            update.get("sectionNumber"),
            update.get("sectionTitle"),
        )
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(update)

    for bucket in buckets.values():
        bucket.sort(key=lambda entry: int(entry.get("sortOrder") or 0))
    return buckets
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Rewrite every section's requirement list from the grouped updates.

    Walks ``content`` recursively through ``子章节``. A section whose key is
    present in ``updates_by_section`` receives those updates, and the entry is
    removed from the mapping. A section with no matching updates has an
    existing ``需求列表`` emptied. Both arguments are modified in place;
    entries left in ``updates_by_section`` matched no section.
    """
    dict_sections = (node for node in content.values() if isinstance(node, dict))
    for section in dict_sections:
        info = section.get("章节信息") or {}
        key = _section_key(
            info.get("章节UID"),
            info.get("章节编号"),
            info.get("章节标题"),
        )

        if key in updates_by_section:
            matched_updates = updates_by_section.pop(key)
            section["需求列表"] = [_to_raw_requirement_item(update) for update in matched_updates]
        elif "需求列表" in section:
            section["需求列表"] = []

        subsections = section.get("子章节")
        if isinstance(subsections, dict):
            _rewrite_content_requirements(subsections, updates_by_section)
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Move all remaining grouped updates into the ``未归类章节`` section.

    The orphan section is created under ``content`` when it is missing or is
    not a dict. Remaining updates are appended to its requirement list, and
    ``updates_by_section`` is emptied. Does nothing when no updates remain.
    """
    if updates_by_section:
        orphan_key = "未归类章节"
        orphan_section = content.get(orphan_key)
        if not isinstance(orphan_section, dict):
            orphan_section = {
                "章节信息": {
                    "章节编号": "",
                    "章节标题": orphan_key,
                    "章节级别": 1,
                },
                "需求列表": [],
            }
            content[orphan_key] = orphan_section

        collected: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
        collected.extend(
            _to_raw_requirement_item(update)
            for group in updates_by_section.values()
            for update in group
        )
        orphan_section["需求列表"] = collected
        updates_by_section.clear()
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
    """Recompute the ``文档元数据`` block of a raw-output document in place.

    Ensures the metadata entry is a dict (replacing it otherwise), then sets
    the total requirement count, the generation time (local ``datetime.now``)
    and the per-type counts keyed by Chinese type label.
    """
    metadata = raw_output.get("文档元数据")
    if not isinstance(metadata, dict):
        metadata = raw_output["文档元数据"] = {}

    metadata["总需求数"] = len(updates)
    metadata["生成时间"] = datetime.now().isoformat()

    counts_by_label: Dict[str, int] = {}
    for update in updates:
        type_code = update.get("requirementType") or "functional"
        # Type codes without a Chinese label are counted under "其他需求".
        label = TYPE_TO_CHINESE.get(type_code, "其他需求")
        counts_by_label[label] = counts_by_label.get(label, 0) + 1
    metadata["需求类型统计"] = counts_by_label
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a camelCase requirement payload into a raw-output item.

    The result always carries the type label, id, description and priority;
    the four interface fields are added only for ``"interface"`` items.
    """
    type_code = item.get("requirementType") or "functional"
    raw_item = {
        "需求类型": TYPE_TO_CHINESE.get(type_code, "其他需求"),
        "需求编号": item.get("id") or "",
        "需求描述": item.get("description") or "",
        "优先级": item.get("priority") or "",
    }
    if type_code != "interface":
        return raw_item

    interface_fields = (
        ("接口名称", "interfaceName"),
        ("接口类型", "interfaceType"),
        ("数据来源", "dataSource"),
        ("数据目的地", "dataDestination"),
    )
    for raw_key, payload_key in interface_fields:
        raw_item[raw_key] = item.get(payload_key) or ""
    return raw_item
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
    """Build a raw-output document from a flat list of requirement payloads.

    Used when no usable structured raw output exists. Requirements are
    grouped by section; each group becomes one top-level entry of ``需求内容``
    keyed by ``"<number> <title>"``.

    Args:
        updates: Requirement payloads in camelCase form.
        document_name: Stored as the document title in the metadata.

    Returns:
        A dict with ``文档元数据`` and ``需求内容``.
    """
    grouped = _group_updates_by_section(updates)
    content: Dict[str, Any] = {}
    for key, values in grouped.items():
        number, title = _parse_section_key(key)
        section_title = title or "未归类章节"
        display_key = f"{number} {section_title}".strip()
        raw_items = [_to_raw_requirement_item(item) for item in values]

        # Several groups can share one display key: every uid-keyed group and
        # the group without section info all resolve to "未归类章节". Merge
        # them instead of overwriting, otherwise earlier groups are silently
        # dropped while the metadata totals still count them.
        existing_section = content.get(display_key)
        if existing_section is not None:
            existing_section["需求列表"].extend(raw_items)
            continue

        content[display_key] = {
            "章节信息": {
                "章节编号": number,
                "章节标题": section_title,
                "章节级别": max(len(number.split(".")), 1) if number else 1,
            },
            "需求列表": raw_items,
        }

    raw_output = {
        "文档元数据": {
            "标题": document_name,
            "生成时间": datetime.now().isoformat(),
            "总需求数": len(updates),
            "需求类型统计": {},
        },
        "需求内容": content,
    }
    # Fills in the per-type statistics (and re-stamps count and time).
    _refresh_metadata(raw_output, updates)
    return raw_output
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
uid = str(section_uid or "").strip()
if uid:
return f"uid::{uid}"
number = str(section_number or "").strip()
title = str(section_title or "").strip()
return f"number::{number}::title::{title}"
def _parse_section_key(value: str) -> Tuple[str, str]:
if value.startswith("uid::"):
return "", "未归类章节"
number = ""
title = ""
parts = value.split("::")
if len(parts) >= 4:
number = parts[1]
title = parts[3]
return number, title