Files
rag_agent/rag-web-ui/backend/app/tools/srs_reqs_qwen/tool.py
2026-04-13 11:34:23 +08:00

149 lines
5.2 KiB
Python

from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, List
import yaml
from app.core.config import settings
from app.tools.base import ToolDefinition
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen.src.document_parser import create_parser
from app.tools.srs_reqs_qwen.src.json_generator import JSONGenerator
from app.tools.srs_reqs_qwen.src.llm_interface import QwenLLM
from app.tools.srs_reqs_qwen.src.requirement_extractor import RequirementExtractor
class SRSTool:
    """Extract structured requirements from an SRS document.

    Wraps the ``srs_reqs_qwen`` pipeline (``create_parser`` ->
    ``RequirementExtractor`` -> ``JSONGenerator``) behind a single
    :meth:`run` entry point, and registers :attr:`DEFINITION` with
    :class:`ToolRegistry` whenever an instance is constructed.
    """

    TOOL_NAME = "srs.requirement_extractor"

    # Model used when neither the settings nor the YAML config name one.
    DEFAULT_MODEL = "qwen3-max"

    # Number of leading description characters used as a requirement title.
    TITLE_MAX_CHARS = 40

    DEFINITION = ToolDefinition(
        name=TOOL_NAME,
        version="1.0.0",
        description="Extract structured requirements from SRS documents.",
        input_schema={
            "type": "object",
            "properties": {
                "file_path": {"type": "string"},
                "enable_llm": {"type": "boolean"},
            },
            "required": ["file_path"],
        },
        output_schema={
            "type": "object",
            "properties": {
                "document_name": {"type": "string"},
                # run() also returns the parsed title; declare it so the
                # schema describes the payload actually emitted.
                "document_title": {"type": "string"},
                "generated_at": {"type": "string"},
                "requirements": {"type": "array"},
                "statistics": {"type": "object"},
                "raw_output": {"type": "object"},
            },
        },
    )

    # Priority label per requirement type; unknown types also map to "".
    # NOTE(review): every label is currently empty — confirm that is intended.
    PRIORITY_BY_TYPE = {
        "functional": "",
        "interface": "",
        "performance": "",
        "security": "",
        "reliability": "",
        "other": "",
    }

    def __init__(self) -> None:
        # Registration happens once per instance; get_srs_tool() shares one.
        ToolRegistry.register(self.DEFINITION)

    def run(self, file_path: str, enable_llm: bool = True) -> Dict[str, Any]:
        """Parse *file_path* and return its extracted requirements.

        Args:
            file_path: Path to the SRS document; passed to ``create_parser``.
            enable_llm: When true and an API key is configured, a ``QwenLLM``
                is attached to the parser and the extractor; otherwise both
                run without an LLM.

        Returns:
            A dict with ``document_name`` (the file name), ``document_title``
            (parsed title, falling back to the file name), ``generated_at``
            (from the generator's metadata; ``None`` if absent),
            ``requirements`` (normalized dicts), ``statistics`` (from the
            extractor) and ``raw_output`` (the generator's full output).
        """
        config = self._load_config()
        llm = self._build_llm(config, enable_llm=enable_llm)

        parser = create_parser(file_path)
        if llm is not None:
            parser.set_llm(llm)
        sections = parser.parse()

        document_name = Path(file_path).name
        document_title = parser.get_document_title() or document_name

        extractor = RequirementExtractor(config, llm=llm)
        extracted = extractor.extract_from_sections(sections)
        stats = extractor.get_statistics()

        raw_output = JSONGenerator(config).generate(sections, extracted, document_title)
        # "文档元数据" / "生成时间" = document metadata / generation time.
        # Tolerate the metadata entry being missing or explicitly null.
        metadata = raw_output.get("文档元数据") or {}

        return {
            "document_name": document_name,
            "document_title": document_title,
            "generated_at": metadata.get("生成时间"),
            "requirements": self._normalize_requirements(extracted),
            "statistics": stats,
            "raw_output": raw_output,
        }

    def _normalize_requirements(self, extracted: List[Any]) -> List[Dict[str, Any]]:
        """Convert extractor output into flat requirement dicts.

        Each item of *extracted* must expose ``id``, ``description``,
        ``section_number``, ``section_title`` and ``type`` attributes.
        ``sort_order`` is the item's 1-based position in *extracted*.
        """
        normalized: List[Dict[str, Any]] = []
        for index, req in enumerate(extracted, start=1):
            description = (req.description or "").strip()
            title = description[: self.TITLE_MAX_CHARS] if description else f"需求项 {index}"
            # Drop missing parts so a None section never renders as "None".
            source_parts = [
                str(part).strip()
                for part in (req.section_number, req.section_title)
                if part is not None and str(part).strip()
            ]
            source_field = " ".join(source_parts) or "文档解析"
            normalized.append(
                {
                    "id": req.id,
                    "title": title,
                    "description": description,
                    "priority": self.PRIORITY_BY_TYPE.get(req.type, ""),
                    "acceptance_criteria": [description] if description else ["待补充验收标准"],
                    "source_field": source_field,
                    "section_number": req.section_number,
                    "section_title": req.section_title,
                    "requirement_type": req.type,
                    "sort_order": index,
                }
            )
        return normalized

    def _load_config(self) -> Dict[str, Any]:
        """Load ``default_config.yaml`` (if present) and overlay LLM settings.

        For each LLM field, a non-empty application setting wins. For the API
        key, the ``DASHSCOPE_API_KEY`` environment variable is tried next.
        Otherwise the value already in the YAML file is kept.
        ``llm.enabled`` is true only when an API key was found.

        Raises:
            ValueError: if the YAML file, or its ``llm`` section, is not a
                mapping.
        """
        config_path = Path(__file__).with_name("default_config.yaml")
        config: Dict[str, Any] = {}
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as handle:
                loaded = yaml.safe_load(handle) or {}  # empty file -> None
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"{config_path} must contain a mapping, got {type(loaded).__name__}"
                )
            config = loaded

        # "llm:" with no value loads as None; treat it as an empty section.
        llm_cfg = config.get("llm") or {}
        if not isinstance(llm_cfg, dict):
            raise ValueError(
                f"'llm' section in {config_path} must be a mapping, "
                f"got {type(llm_cfg).__name__}"
            )
        config["llm"] = llm_cfg

        # Only let non-empty settings override the file, so an unset setting
        # does not wipe out a value configured in YAML.
        llm_cfg["model"] = (
            settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL or llm_cfg.get("model")
        )
        llm_cfg["api_key"] = (
            settings.DASH_SCOPE_API_KEY
            or os.getenv("DASHSCOPE_API_KEY", "")
            or llm_cfg.get("api_key")
            or ""
        )
        llm_cfg["api_base"] = settings.DASH_SCOPE_API_BASE or llm_cfg.get("api_base")
        llm_cfg["enabled"] = bool(llm_cfg["api_key"])
        return config

    def _build_llm(self, config: Dict[str, Any], enable_llm: bool) -> QwenLLM | None:
        """Return a ``QwenLLM`` built from *config*, or ``None``.

        ``None`` is returned when *enable_llm* is false or no API key is
        configured. An empty or missing model falls back to
        :attr:`DEFAULT_MODEL`.
        """
        if not enable_llm:
            return None
        llm_cfg = config.get("llm") or {}
        api_key = llm_cfg.get("api_key")
        if not api_key:
            return None
        return QwenLLM(
            api_key=api_key,
            # `or` instead of a .get() default: _load_config always writes the
            # key, so a .get() default would never replace an empty value.
            model=llm_cfg.get("model") or self.DEFAULT_MODEL,
            api_endpoint=llm_cfg.get("api_base") or settings.DASH_SCOPE_API_BASE,
            temperature=llm_cfg.get("temperature", 0.3),
            max_tokens=llm_cfg.get("max_tokens", 1024),
        )
# Lazily created, module-wide SRSTool instance (None until first requested).
_SRS_TOOL_SINGLETON: SRSTool | None = None


def get_srs_tool() -> SRSTool:
    """Return the shared SRSTool, constructing it on the first call.

    Constructing an SRSTool registers its tool definition with ToolRegistry,
    so reusing one instance keeps that registration to a single call.
    """
    global _SRS_TOOL_SINGLETON
    tool = _SRS_TOOL_SINGLETON
    if tool is not None:
        return tool
    tool = SRSTool()
    _SRS_TOOL_SINGLETON = tool
    return tool