Files
rag_agent/rag-web-ui/backend/app/services/srs_job_service.py

472 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool
# Chinese display label for each canonical requirement-type code. The labels are
# written into the raw-output items and the per-type statistics built in this module.
TYPE_TO_CHINESE = dict(
    functional="功能需求",
    interface="接口需求",
    performance="性能需求",
    security="安全需求",
    reliability="可靠性需求",
    other="其他需求",
)
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
text = str(description or "").strip()
if not text:
return fallback or f"需求项 {index + 1}"
for separator in ("", "", "\n", ";", "."):
if separator in text:
text = text.split(separator, 1)[0].strip()
break
text = text[:20].strip()
return text or fallback or f"需求项 {index + 1}"
def _normalize_requirement_type(value: Any) -> str:
text = str(value or "").strip()
if text in {"functional", "interface", "performance", "security", "reliability", "other"}:
return text
chinese_map = {
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
"功能需求": "functional",
}
return chinese_map.get(text, "functional")
def run_srs_job(job_id: int) -> None:
    """Run the SRS extraction tool for one job and persist the result.

    The job is marked ``processing`` and committed, the tool is run on the job's
    input file, one ``SRSExtraction`` plus one ``SRSRequirement`` per extracted
    item is stored, and the job is marked ``completed``. Any exception rolls the
    session back and the job is recorded as ``failed`` via ``_mark_job_failed``.
    Nothing happens when no job with ``job_id`` exists.

    Args:
        job_id: Primary key of the ``ToolJob`` to process.
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        # Commit the "processing" state before invoking the tool.
        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        session.commit()

        payload = get_srs_tool().run(job.input_file_path)

        document_name = payload["document_name"]
        extraction = SRSExtraction(
            job_id=job.id,
            document_name=document_name,
            document_title=payload.get("document_title") or document_name,
            generated_at=_parse_generated_at(payload.get("generated_at")),
            total_requirements=len(payload.get("requirements", [])),
            statistics=payload.get("statistics", {}),
            raw_output=payload.get("raw_output", {}),
        )
        session.add(extraction)
        # extraction.id is used as the foreign key below; presumably the flush is
        # what makes the database assign it.
        session.flush()

        for raw_item in payload.get("requirements", []):
            uid = raw_item["id"]
            description = raw_item.get("description")
            session.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=_build_internal_title(description, uid),
                    description=description or "",
                    priority=raw_item.get("priority") or "",
                    acceptance_criteria=raw_item.get("acceptance_criteria") or ["待补充验收标准"],
                    source_field=raw_item.get("source_field") or "文档解析",
                    section_uid=raw_item.get("section_uid"),
                    section_number=raw_item.get("section_number"),
                    section_title=raw_item.get("section_title"),
                    requirement_type=raw_item.get("requirement_type"),
                    interface_name=raw_item.get("interface_name"),
                    interface_type=raw_item.get("interface_type"),
                    data_source=raw_item.get("data_source"),
                    data_destination=raw_item.get("data_destination"),
                    sort_order=int(raw_item.get("sort_order") or 0),
                )
            )

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "total_requirements": extraction.total_requirements,
            "document_name": extraction.document_name,
        }
        session.commit()
    except Exception as exc:
        # Discard the partial work, then record the failure through a separate session.
        session.rollback()
        _mark_job_failed(job_id=job_id, error_message=str(exc))
    finally:
        session.close()
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Record a job as failed, using a session of its own.

    Args:
        job_id: Primary key of the ``ToolJob`` to update; a missing job is ignored.
        error_message: Failure text; only the first 2000 characters are stored.
    """
    session = SessionLocal()
    try:
        job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if job:
            job.status = "failed"
            job.completed_at = datetime.utcnow()
            job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
def _parse_generated_at(value: Any) -> datetime:
if isinstance(value, str):
try:
return datetime.fromisoformat(value)
except ValueError:
return datetime.utcnow()
return datetime.utcnow()
def ensure_upload_path(job_id: int, file_name: str) -> Path:
    """Create the upload directory for a job and return the target file path.

    Only the final path component of ``file_name`` is used, so a name such as
    ``"../../x"`` or an absolute path cannot point outside the job directory.

    Args:
        job_id: Job identifier; becomes the last directory component.
        file_name: Name of the uploaded file (presumably client-supplied -
            verify against the caller).

    Returns:
        ``uploads/srs_jobs/<job_id>/<sanitised name>``; the directory exists on return.
    """
    target_dir = Path("uploads") / "srs_jobs" / str(job_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Treat backslashes as separators too, so Windows-style paths are reduced to
    # their last component on every platform.
    safe_name = Path(str(file_name).replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        # Nothing usable is left after stripping directories; use a neutral name
        # instead of returning the directory itself or its parent.
        safe_name = "upload"
    return target_dir / safe_name
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
    """Build the API response for a finished extraction.

    The stored raw output is re-merged with the extraction's current requirement
    rows so that ``rawOutput`` reflects them.

    Args:
        job: The job the extraction belongs to; supplies ``jobId``.
        extraction: The extraction whose requirements and metadata are returned.

    Returns:
        A dict with the keys ``jobId``, ``documentName``, ``generatedAt``,
        ``statistics``, ``requirements`` and ``rawOutput``.
    """
    requirement_payloads = list(map(_requirement_model_to_payload, extraction.requirements))
    merged_raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        requirement_payloads,
        extraction.document_name,
    )
    response: Dict[str, Any] = {
        "jobId": job.id,
        "documentName": extraction.document_name,
        "generatedAt": extraction.generated_at.isoformat(),
        "statistics": extraction.statistics or {},
        "requirements": requirement_payloads,
        "rawOutput": merged_raw_output,
    }
    return response
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
    """Synchronise an extraction's requirement rows with an edited list.

    Each update is normalised, then matched to an existing row by requirement id:
    matches are updated in place, unknown ids are inserted, and rows whose id is
    absent from ``updates`` are deleted. The extraction's totals, statistics and
    raw output are rebuilt from the normalised updates. This function does not
    commit the session.

    Args:
        db: Session used for querying, adding and deleting rows.
        extraction: The extraction whose requirements are replaced.
        updates: Requirement payloads using camelCase keys.
    """
    normalized = [_normalize_update_payload(entry, position) for position, entry in enumerate(updates)]

    current_rows = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id).all()
    rows_by_uid = {row.requirement_uid: row for row in current_rows}

    # (model attribute, payload key) pairs copied verbatim on update, in this order.
    passthrough_fields = (
        ("priority", "priority"),
        ("acceptance_criteria", "acceptanceCriteria"),
        ("source_field", "sourceField"),
        ("section_uid", "sectionUid"),
        ("section_number", "sectionNumber"),
        ("section_title", "sectionTitle"),
        ("requirement_type", "requirementType"),
        ("interface_name", "interfaceName"),
        ("interface_type", "interfaceType"),
        ("data_source", "dataSource"),
        ("data_destination", "dataDestination"),
    )

    kept_uids = set()
    for position, update in enumerate(normalized):
        uid = update["id"]
        kept_uids.add(uid)
        row = rows_by_uid.get(uid)

        if row is not None:
            row.description = update.get("description", row.description)
            # The title is derived from the description that was just assigned.
            row.title = _build_internal_title(
                row.description,
                row.requirement_uid,
                int(update.get("sortOrder") or position),
            )
            for attribute, key in passthrough_fields:
                setattr(row, attribute, update.get(key, getattr(row, attribute)))
            row.sort_order = int(update.get("sortOrder", row.sort_order))
        else:
            description = update.get("description")
            if description is None:
                description = ""
            db.add(
                SRSRequirement(
                    extraction_id=extraction.id,
                    requirement_uid=uid,
                    title=_build_internal_title(description, uid, int(update.get("sortOrder") or position)),
                    description=description,
                    priority=update.get("priority") or "",
                    acceptance_criteria=update.get("acceptanceCriteria") or ["待补充验收标准"],
                    source_field=update.get("sourceField") or "文档解析",
                    section_uid=update.get("sectionUid"),
                    section_number=update.get("sectionNumber"),
                    section_title=update.get("sectionTitle"),
                    requirement_type=update.get("requirementType"),
                    interface_name=update.get("interfaceName"),
                    interface_type=update.get("interfaceType"),
                    data_source=update.get("dataSource"),
                    data_destination=update.get("dataDestination"),
                    sort_order=int(update.get("sortOrder") or 0),
                )
            )

    # Rows that no update referred to are removed.
    for uid, row in rows_by_uid.items():
        if uid in kept_uids:
            continue
        db.delete(row)

    total = len(normalized)
    extraction.total_requirements = total
    extraction.statistics = {
        "total": total,
        "by_type": _count_requirement_types(normalized),
    }
    extraction.raw_output = _merge_updates_into_raw_output(
        extraction.raw_output,
        normalized,
        extraction.document_name,
    )
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
    """List a user's SRS jobs that have an extraction, newest job first.

    Args:
        db: Session used for the query.
        user_id: Owner of the jobs to list.

    Returns:
        One dict per (job, extraction) pair with the keys ``jobId``,
        ``documentName``, ``generatedAt``, ``totalRequirements``, ``status``,
        ``createdAt`` and ``updatedAt``; timestamps are ISO-format strings.
    """
    query = (
        db.query(ToolJob, SRSExtraction)
        .join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
        .filter(ToolJob.user_id == user_id)
        .order_by(ToolJob.created_at.desc())
    )
    rows: List[Tuple[ToolJob, SRSExtraction]] = query.all()
    return [
        {
            "jobId": job.id,
            "documentName": extraction.document_name,
            "generatedAt": extraction.generated_at.isoformat(),
            "totalRequirements": extraction.total_requirements,
            "status": job.status,
            "createdAt": job.created_at.isoformat(),
            "updatedAt": job.updated_at.isoformat(),
        }
        for job, extraction in rows
    ]
def delete_srs_job(db: Session, job: ToolJob) -> None:
    """Delete a job through the given session and commit immediately.

    Args:
        db: Session the job is attached to.
        job: The job row to remove.
    """
    db.delete(job)
    db.commit()
def build_srs_upload_path(job_id: int) -> Path:
    """Return the relative upload directory for a job without creating it.

    Args:
        job_id: Job identifier; becomes the last path component.

    Returns:
        The relative path ``uploads/srs_jobs/<job_id>``.
    """
    return Path("uploads", "srs_jobs", str(job_id))
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
stats: Dict[str, int] = {}
for item in items:
req_type = item.get("requirementType") or "functional"
stats[req_type] = stats.get(req_type, 0) + 1
return stats
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
    """Convert a requirement row to its camelCase API payload.

    Args:
        item: The requirement row to serialise.

    Returns:
        A dict keyed by the camelCase field names; ``acceptanceCriteria`` is an
        empty list when the stored value is falsy.
    """
    # (payload key, model attribute) pairs, listed in output order.
    field_map = (
        ("id", "requirement_uid"),
        ("description", "description"),
        ("priority", "priority"),
        ("acceptanceCriteria", "acceptance_criteria"),
        ("sourceField", "source_field"),
        ("sectionUid", "section_uid"),
        ("sectionNumber", "section_number"),
        ("sectionTitle", "section_title"),
        ("requirementType", "requirement_type"),
        ("interfaceName", "interface_name"),
        ("interfaceType", "interface_type"),
        ("dataSource", "data_source"),
        ("dataDestination", "data_destination"),
        ("sortOrder", "sort_order"),
    )
    payload: Dict[str, Any] = {key: getattr(item, attribute) for key, attribute in field_map}
    # Re-assigning an existing key keeps its position in the dict.
    payload["acceptanceCriteria"] = payload["acceptanceCriteria"] or []
    return payload
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
    """Fill defaults and canonicalise one client-supplied requirement update.

    Args:
        item: Raw update payload using camelCase keys.
        index: Zero-based position of the update; used to generate a missing id
            (``REQ-001`` style) and as the sort order when ``sortOrder`` is falsy.

    Returns:
        A dict containing every payload key. The four interface fields are blank
        unless the normalised requirement type is ``"interface"``.
    """
    type_code = _normalize_requirement_type(item.get("requirementType"))
    is_interface = type_code == "interface"

    def interface_field(key: str) -> Any:
        # Interface details are only kept for interface requirements.
        value = item.get(key) or ""
        return value if is_interface else ""

    return {
        "id": str(item.get("id") or f"REQ-{index + 1:03d}"),
        "description": str(item.get("description") or "").strip(),
        "priority": item.get("priority") or "",
        "acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
        "sourceField": item.get("sourceField") or "文档解析",
        "sectionUid": item.get("sectionUid"),
        "sectionNumber": item.get("sectionNumber"),
        "sectionTitle": item.get("sectionTitle"),
        "requirementType": type_code,
        "interfaceName": interface_field("interfaceName"),
        "interfaceType": interface_field("interfaceType"),
        "dataSource": interface_field("dataSource"),
        "dataDestination": interface_field("dataDestination"),
        "sortOrder": int(item.get("sortOrder") or index),
    }
def _merge_updates_into_raw_output(
    raw_output: Dict[str, Any] | None,
    updates: List[Dict[str, Any]],
    document_name: str,
) -> Dict[str, Any]:
    """Return a copy of the raw output whose requirement lists reflect ``updates``.

    When ``raw_output`` is not a dict, has no ``"需求内容"`` key, or that entry is
    not a dict, a new structure is built from ``updates`` alone. The input is
    never modified; the work is done on a deep copy.

    Args:
        raw_output: Previously stored raw output, if any.
        updates: Requirement payloads using camelCase keys.
        document_name: Title used when a new structure has to be built.

    Returns:
        The merged (or newly built) raw-output dict.
    """
    has_content_key = isinstance(raw_output, dict) and "需求内容" in raw_output
    if not has_content_key:
        return _build_raw_output_from_flat(updates, document_name)

    merged = deepcopy(raw_output)
    section_tree = merged.get("需求内容")
    if not isinstance(section_tree, dict):
        return _build_raw_output_from_flat(updates, document_name)

    pending = _group_updates_by_section(updates)
    # Matched groups are removed from `pending`; whatever is left is appended afterwards.
    _rewrite_content_requirements(section_tree, pending)
    _append_unmatched_sections(section_tree, pending)
    _refresh_metadata(merged, updates)
    return merged
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group requirement payloads by section key, each group ordered by ``sortOrder``.

    Args:
        updates: Requirement payloads using camelCase keys.

    Returns:
        A mapping from section key (as produced by ``_section_key``) to the
        payloads of that section, in first-seen key order. Within a group the
        sort is stable, and a falsy ``sortOrder`` counts as 0.
    """

    def sort_position(entry: Dict[str, Any]) -> int:
        return int(entry.get("sortOrder") or 0)

    buckets: Dict[str, List[Dict[str, Any]]] = {}
    for update in updates:
        key = _section_key(
            update.get("sectionUid"),
            update.get("sectionNumber"),
            update.get("sectionTitle"),
        )
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(update)

    for bucket in buckets.values():
        bucket.sort(key=sort_position)
    return buckets
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Replace each section's requirement list in place, recursing into sub-sections.

    A section whose key is present in ``updates_by_section`` receives those
    requirements and the group is removed from the mapping. A section without a
    match has an existing ``"需求列表"`` emptied. Non-dict entries are skipped.

    Args:
        content: Mapping of section name to section dict; modified in place.
        updates_by_section: Pending requirement groups by section key; matched
            groups are popped.
    """
    for node in content.values():
        if not isinstance(node, dict):
            continue

        info = node.get("章节信息") or {}
        key = _section_key(
            info.get("章节UID"),
            info.get("章节编号"),
            info.get("章节标题"),
        )

        if key in updates_by_section:
            matched = updates_by_section.pop(key)
            node["需求列表"] = list(map(_to_raw_requirement_item, matched))
        elif "需求列表" in node:
            # No update refers to this section any more, so its list is emptied.
            node["需求列表"] = []

        sub_sections = node.get("子章节")
        if isinstance(sub_sections, dict):
            _rewrite_content_requirements(sub_sections, updates_by_section)
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
    """Move every still-pending requirement group into the "未归类章节" section.

    The section is created at the top level of ``content`` when it is missing or
    not a dict. ``updates_by_section`` is emptied afterwards. Nothing happens
    when there are no pending groups.

    Args:
        content: Top-level section mapping; modified in place.
        updates_by_section: Requirement groups that matched no existing section.
    """
    if not updates_by_section:
        return

    orphan_key = "未归类章节"
    orphan_section = content.get(orphan_key)
    if not isinstance(orphan_section, dict):
        orphan_section = {
            "章节信息": {
                "章节编号": "",
                "章节标题": orphan_key,
                "章节级别": 1,
            },
            "需求列表": [],
        }
        content[orphan_key] = orphan_section

    collected: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
    collected.extend(
        _to_raw_requirement_item(entry)
        for group in updates_by_section.values()
        for entry in group
    )
    orphan_section["需求列表"] = collected
    updates_by_section.clear()
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
    """Recompute the ``"文档元数据"`` block of a raw output in place.

    Sets the total requirement count, stamps the current local time as the
    generation time, and rebuilds the per-type counts keyed by Chinese label.
    A missing or non-dict metadata block is replaced with a new dict.

    Args:
        raw_output: Raw-output dict to update; modified in place.
        updates: Requirement payloads the metadata should describe.
    """
    meta = raw_output.get("文档元数据")
    if not isinstance(meta, dict):
        meta = {}
        raw_output["文档元数据"] = meta

    meta["总需求数"] = len(updates)
    meta["生成时间"] = datetime.now().isoformat()

    counts_by_label: Dict[str, int] = {}
    for entry in updates:
        type_code = entry.get("requirementType") or "functional"
        # Unknown type codes are counted under the "other" label.
        label = TYPE_TO_CHINESE.get(type_code, "其他需求")
        counts_by_label[label] = counts_by_label.get(label, 0) + 1
    meta["需求类型统计"] = counts_by_label
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a camelCase requirement payload to a raw-output item with Chinese keys.

    Args:
        item: Requirement payload; a missing or falsy ``requirementType`` is
            treated as ``"functional"``.

    Returns:
        A dict with type label, id, description and priority, plus the four
        interface fields when the type is ``"interface"``.
    """
    type_code = item.get("requirementType") or "functional"
    converted = {
        "需求类型": TYPE_TO_CHINESE.get(type_code, "其他需求"),
        "需求编号": item.get("id") or "",
        "需求描述": item.get("description") or "",
        "优先级": item.get("priority") or "",
    }
    if type_code != "interface":
        return converted

    converted.update(
        {
            "接口名称": item.get("interfaceName") or "",
            "接口类型": item.get("interfaceType") or "",
            "数据来源": item.get("dataSource") or "",
            "数据目的地": item.get("dataDestination") or "",
        }
    )
    return converted
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
    """Build a raw-output structure from a flat requirement list.

    Requirements are grouped by section key and each group becomes one
    top-level section named ``"<number> <title>"``. Groups without a usable
    title are placed under "未归类章节".

    Args:
        updates: Requirement payloads using camelCase keys.
        document_name: Stored as the document title in the metadata block.

    Returns:
        A dict with the ``"文档元数据"`` and ``"需求内容"`` blocks.
    """
    grouped = _group_updates_by_section(updates)
    content: Dict[str, Any] = {}
    for key, values in grouped.items():
        number, title = _parse_section_key(key)
        section_title = title or "未归类章节"
        display_key = f"{number} {section_title}".strip()
        raw_items = [_to_raw_requirement_item(item) for item in values]

        existing_section = content.get(display_key)
        if existing_section is not None:
            # Different section keys can share a display name (every uid-keyed
            # group maps to "未归类章节", for example). Merge them rather than
            # overwrite, otherwise the earlier group's requirements are lost.
            existing_section["需求列表"].extend(raw_items)
            continue

        content[display_key] = {
            "章节信息": {
                "章节编号": number,
                "章节标题": section_title,
                "章节级别": max(len(number.split(".")), 1) if number else 1,
            },
            "需求列表": raw_items,
        }
    raw_output = {
        "文档元数据": {
            "标题": document_name,
            "生成时间": datetime.now().isoformat(),
            "总需求数": len(updates),
            "需求类型统计": {},
        },
        "需求内容": content,
    }
    # Fills in the per-type statistics (and re-stamps count and time).
    _refresh_metadata(raw_output, updates)
    return raw_output
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
uid = str(section_uid or "").strip()
if uid:
return f"uid::{uid}"
number = str(section_number or "").strip()
title = str(section_title or "").strip()
return f"number::{number}::title::{title}"
def _parse_section_key(value: str) -> Tuple[str, str]:
if value.startswith("uid::"):
return "", "未归类章节"
number = ""
title = ""
parts = value.split("::")
if len(parts) >= 4:
number = parts[1]
title = parts[3]
return number, title