213 lines
8.2 KiB
Python
213 lines
8.2 KiB
Python
from __future__ import annotations
|
||
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List
|
||
|
||
import yaml
|
||
|
||
from app.core.config import settings
|
||
from app.tools.base import ToolDefinition
|
||
from app.tools.registry import ToolRegistry
|
||
from app.tools.srs_reqs_qwen.src.document_parser import create_parser
|
||
from app.tools.srs_reqs_qwen.src.json_generator import JSONGenerator
|
||
from app.tools.srs_reqs_qwen.src.llm_interface import QwenLLM
|
||
from app.tools.srs_reqs_qwen.src.requirement_extractor import RequirementExtractor
|
||
|
||
|
||
class SRSTool:
    """Tool wrapper that extracts structured requirements from SRS documents.

    Creating an instance registers ``DEFINITION`` with ``ToolRegistry``.
    Extraction requires an LLM; see ``run``.
    """

    TOOL_NAME = "srs.requirement_extractor"

    DEFINITION = ToolDefinition(
        name=TOOL_NAME,
        version="1.0.0",
        description="Extract structured requirements from SRS documents.",
        input_schema={
            "type": "object",
            "properties": {
                "file_path": {"type": "string"},
                "enable_llm": {"type": "boolean"},
            },
            "required": ["file_path"],
        },
        output_schema={
            "type": "object",
            "properties": {
                "document_name": {"type": "string"},
                # ``run`` returns this key as well; declare it so the schema matches.
                "document_title": {"type": "string"},
                "generated_at": {"type": "string"},
                "requirements": {"type": "array"},
                "statistics": {"type": "object"},
                "raw_output": {"type": "object"},
            },
        },
    )

    # Canonical requirement types -> priority label (中 = medium, 高 = high, 低 = low).
    # Only the keys are consulted in this class: they form the set of accepted
    # requirement types. NOTE(review): the priority values are not used for the
    # emitted "priority" field (always "中") -- confirm whether that is intended.
    PRIORITY_BY_TYPE = {
        "functional": "中",
        "interface": "高",
        "performance": "中",
        "security": "高",
        "reliability": "中",
        "other": "低",
    }

    # Chinese requirement-type labels accepted as aliases of the canonical types.
    TYPE_ALIASES = {
        "功能需求": "functional",
        "接口需求": "interface",
        "性能需求": "performance",
        "安全需求": "security",
        "可靠性需求": "reliability",
        "其他需求": "other",
    }

    # Placeholder values (compared stripped and lower-cased) that mean "no
    # interface information"; any other non-empty value marks an interface field.
    UNKNOWN_INTERFACE_VALUES = {"", "未知", "unknown", "n/a", "-", "--", "无", "none", "null"}

    def __init__(self) -> None:
        """Register this tool's definition with the global ``ToolRegistry``."""
        ToolRegistry.register(self.DEFINITION)

    def run(self, file_path: str, enable_llm: bool = True, model_profile: Any = None) -> Dict[str, Any]:
        """Parse ``file_path`` and extract its requirements with the LLM.

        Args:
            file_path: Path of the SRS document; ``create_parser`` picks the parser.
            enable_llm: Must be true; the non-LLM mode is not supported.
            model_profile: Optional object whose ``chat_model`` / ``api_key`` /
                ``api_base`` attributes override the settings-based LLM config.

        Returns:
            Dict with the document name and title, the generation time taken from
            the generator's metadata, the normalized requirements, the extractor
            statistics and the raw generator output.

        Raises:
            ValueError: if ``enable_llm`` is false or no API key is configured.
        """
        if not enable_llm:
            raise ValueError("当前版本仅支持LLM模式,请将 enable_llm 设为 true")

        config = self._load_config(model_profile=model_profile)
        llm = self._build_llm(config, enable_llm=enable_llm)

        parser = create_parser(file_path)
        if llm is not None:
            parser.set_llm(llm)

        sections = parser.parse()
        document_name = Path(file_path).name
        document_title = parser.get_document_title() or document_name

        extractor = RequirementExtractor(config, llm=llm)
        extracted = extractor.extract_from_sections(sections)
        stats = extractor.get_statistics()

        generator = JSONGenerator(config)
        raw_output = generator.generate(sections, extracted, document_title)

        requirements = self._normalize_requirements(extracted)

        # "文档元数据" = document metadata, "生成时间" = generation time.
        # Tolerate a missing *or* null metadata section.
        metadata = raw_output.get("文档元数据") or {}

        return {
            "document_name": document_name,
            "document_title": document_title,
            "generated_at": metadata.get("生成时间"),
            "requirements": requirements,
            "statistics": stats,
            "raw_output": raw_output,
        }

    def _normalize_requirements(self, extracted: List[Any]) -> List[Dict[str, Any]]:
        """Convert extractor results into flat, ordered requirement dicts.

        Interface-specific fields are only populated when the requirement is
        classified as ``"interface"``; otherwise they are empty strings.
        ``sort_order`` is the 1-based position in ``extracted``.
        """
        normalized: List[Dict[str, Any]] = []
        for index, req in enumerate(extracted, start=1):
            description = (req.description or "").strip()
            title = self._build_short_title(description, index)

            # Read the optional interface attributes once, tolerating their absence,
            # and reuse the same values for classification and output.
            interface_name = getattr(req, "interface_name", "")
            interface_type = getattr(req, "interface_type", "")
            data_source = getattr(req, "source", "")
            data_destination = getattr(req, "destination", "")

            requirement_type = self._normalize_requirement_type(
                req_type=getattr(req, "type", "functional"),
                interface_name=interface_name,
                interface_type=interface_type,
                data_source=data_source,
                data_destination=data_destination,
            )
            is_interface = requirement_type == "interface"

            # Render None as empty so the field never reads "None ...".
            number_text = "" if req.section_number is None else req.section_number
            title_text = "" if req.section_title is None else req.section_title
            source_field = f"{number_text} {title_text}".strip() or "文档解析"

            normalized.append(
                {
                    "id": req.id,
                    "title": title,
                    "description": description,
                    "priority": "中",
                    "acceptance_criteria": [description] if description else ["待补充验收标准"],
                    "source_field": source_field,
                    "section_uid": req.section_uid,
                    "section_number": req.section_number,
                    "section_title": req.section_title,
                    "requirement_type": requirement_type,
                    "interface_name": interface_name if is_interface else "",
                    "interface_type": interface_type if is_interface else "",
                    "data_source": data_source if is_interface else "",
                    "data_destination": data_destination if is_interface else "",
                    "sort_order": index,
                }
            )
        return normalized

    def _normalize_requirement_type(
        self,
        req_type: Any,
        interface_name: Any,
        interface_type: Any,
        data_source: Any,
        data_destination: Any,
    ) -> str:
        """Map a raw requirement type to one of the keys of ``PRIORITY_BY_TYPE``.

        Chinese labels are translated via ``TYPE_ALIASES``; English labels are
        matched case-insensitively. Unknown types fall back to ``"functional"``.
        A requirement is promoted to ``"interface"`` when any interface field
        holds a value that is not in ``UNKNOWN_INTERFACE_VALUES``.
        """
        raw_type = str(req_type or "").strip()
        normalized_type = self.TYPE_ALIASES.get(raw_type, raw_type.lower())
        if normalized_type not in self.PRIORITY_BY_TYPE:
            normalized_type = "functional"

        if normalized_type == "interface":
            return "interface"

        has_interface_fields = any(
            str(value or "").strip().lower() not in self.UNKNOWN_INTERFACE_VALUES
            for value in (interface_name, interface_type, data_source, data_destination)
        )
        return "interface" if has_interface_fields else normalized_type

    def _build_short_title(self, description: str, index: int, max_length: int = 20) -> str:
        """Build a short title from the first clause of ``description``.

        The text is cut at the earliest sentence/clause separator (。 ； ; newline,
        or a "." that is not between two digits), then truncated to
        ``max_length`` characters. Falls back to ``需求项 {index}`` when nothing
        usable remains.
        """
        import re  # local import: only this helper needs regular expressions

        text = (description or "").strip()
        if text:
            # A "." between digits ("0.5秒", "3.2.1") belongs to a number, not a sentence end.
            text = re.split(r"[。；;\n]|(?<!\d)\.|\.(?!\d)", text, maxsplit=1)[0].strip()
        if not text:
            return f"需求项 {index}"
        return text[:max_length].rstrip()

    def _load_config(self, model_profile: Any = None) -> Dict[str, Any]:
        """Load ``default_config.yaml`` (next to this module) and fill in LLM settings.

        The ``llm`` section's ``model`` / ``api_key`` / ``api_base`` are taken from
        ``model_profile`` when given, otherwise from application settings (the
        API key also falls back to the ``DASHSCOPE_API_KEY`` environment
        variable). ``llm.enabled`` is true iff an API key is present.
        """
        config_path = Path(__file__).with_name("default_config.yaml")
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as handle:
                config = yaml.safe_load(handle) or {}
        else:
            config = {}

        # ``setdefault`` would keep an explicit ``llm: null`` from the YAML; replace it.
        llm_cfg = config.get("llm") or {}
        config["llm"] = llm_cfg

        if model_profile is not None:
            llm_cfg["model"] = getattr(model_profile, "chat_model", None) or settings.DASH_SCOPE_CHAT_MODEL
            llm_cfg["api_key"] = getattr(model_profile, "api_key", "") or ""
            llm_cfg["api_base"] = getattr(model_profile, "api_base", None) or settings.DASH_SCOPE_API_BASE
        else:
            llm_cfg["model"] = settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
            llm_cfg["api_key"] = settings.DASH_SCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
            llm_cfg["api_base"] = settings.DASH_SCOPE_API_BASE
        llm_cfg["enabled"] = bool(llm_cfg.get("api_key"))
        return config

    def _build_llm(self, config: Dict[str, Any], enable_llm: bool) -> QwenLLM | None:
        """Create the ``QwenLLM`` client from ``config["llm"]``.

        Raises:
            ValueError: if ``enable_llm`` is false or no API key is configured.
        """
        if not enable_llm:
            raise ValueError("当前版本仅支持LLM模式")

        llm_cfg = config.get("llm") or {}
        api_key = llm_cfg.get("api_key")
        if not api_key:
            raise ValueError("未配置模型 API 密钥,请先在 API 密钥页面新增并启用模型配置。")

        # ``_load_config`` always writes these keys, possibly as None, so
        # ``dict.get(key, default)`` would never apply the default. Resolve
        # None explicitly; 0 stays a valid temperature.
        temperature = llm_cfg.get("temperature")
        max_tokens = llm_cfg.get("max_tokens")
        return QwenLLM(
            api_key=api_key,
            model=llm_cfg.get("model") or "qwen3-max",
            api_endpoint=llm_cfg.get("api_base") or settings.DASH_SCOPE_API_BASE,
            temperature=0.3 if temperature is None else temperature,
            max_tokens=1024 if max_tokens is None else max_tokens,
        )
|
||
|
||
|
||
# Lazily created, module-wide SRSTool instance (see get_srs_tool).
_SRS_TOOL_SINGLETON: SRSTool | None = None


def get_srs_tool() -> SRSTool:
    """Return the shared ``SRSTool``, constructing it on the first call.

    Later calls return the cached instance, so the tool's ``__init__``
    (which registers its definition) runs once per process under normal use.
    """
    global _SRS_TOOL_SINGLETON
    tool = _SRS_TOOL_SINGLETON
    if tool is None:
        tool = SRSTool()
        _SRS_TOOL_SINGLETON = tool
    return tool
|