Files
rag_agent/rag-web-ui/backend/app/tools/srs_reqs_qwen/tool.py

152 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, List
import yaml
from app.core.config import settings
from app.tools.base import ToolDefinition
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen.src.document_parser import create_parser
from app.tools.srs_reqs_qwen.src.json_generator import JSONGenerator
from app.tools.srs_reqs_qwen.src.llm_interface import QwenLLM
from app.tools.srs_reqs_qwen.src.requirement_extractor import RequirementExtractor
class SRSTool:
    """Extract structured requirements from SRS (software requirement spec) documents.

    Pipeline: parse the document into sections, run LLM-assisted requirement
    extraction over those sections, generate a raw JSON report, and return a
    normalized requirement list alongside the raw output and statistics.

    Instantiating the tool registers its definition with the global
    ``ToolRegistry``; use ``get_srs_tool()`` to obtain the shared instance.
    """

    TOOL_NAME = "srs.requirement_extractor"

    DEFINITION = ToolDefinition(
        name=TOOL_NAME,
        version="1.0.0",
        description="Extract structured requirements from SRS documents.",
        input_schema={
            "type": "object",
            "properties": {
                "file_path": {"type": "string"},
                "enable_llm": {"type": "boolean"},
            },
            "required": ["file_path"],
        },
        output_schema={
            "type": "object",
            "properties": {
                "document_name": {"type": "string"},
                # run() returns document_title as well; it was missing from
                # the declared schema, so it is added here for consistency.
                "document_title": {"type": "string"},
                "generated_at": {"type": "string"},
                "requirements": {"type": "array"},
                "statistics": {"type": "object"},
                "raw_output": {"type": "object"},
            },
        },
    )

    # Maps a requirement type to a display priority label.
    # NOTE(review): every value is an empty string in the source.  The file
    # carried an "ambiguous Unicode" warning, so the original labels were
    # likely lost in a copy/paste -- confirm the intended priority values.
    PRIORITY_BY_TYPE = {
        "functional": "",
        "interface": "",
        "performance": "",
        "security": "",
        "reliability": "",
        "other": "",
    }

    def __init__(self) -> None:
        """Register this tool's definition with the global registry."""
        ToolRegistry.register(self.DEFINITION)

    def run(self, file_path: str, enable_llm: bool = True) -> Dict[str, Any]:
        """Parse *file_path* and extract its requirements.

        Args:
            file_path: Path to the SRS document to process.
            enable_llm: Must be ``True``; the current version only supports
                LLM-backed extraction.

        Returns:
            A dict with ``document_name``, ``document_title``,
            ``generated_at``, the normalized ``requirements`` list,
            extraction ``statistics``, and the generator's ``raw_output``.

        Raises:
            ValueError: If ``enable_llm`` is False or no API key is configured.
        """
        if not enable_llm:
            raise ValueError("当前版本仅支持LLM模式请将 enable_llm 设为 true")
        config = self._load_config()
        llm = self._build_llm(config, enable_llm=enable_llm)
        parser = create_parser(file_path)
        # _build_llm raises rather than returning None, but keep the guard in
        # case its contract (QwenLLM | None) is ever exercised.
        if llm is not None:
            parser.set_llm(llm)
        sections = parser.parse()
        # Fall back to the file name when the parser finds no title.
        document_title = parser.get_document_title() or Path(file_path).name
        extractor = RequirementExtractor(config, llm=llm)
        extracted = extractor.extract_from_sections(sections)
        stats = extractor.get_statistics()
        generator = JSONGenerator(config)
        raw_output = generator.generate(sections, extracted, document_title)
        requirements = self._normalize_requirements(extracted)
        return {
            "document_name": Path(file_path).name,
            "document_title": document_title,
            # The generator embeds the timestamp under Chinese metadata keys.
            "generated_at": raw_output.get("文档元数据", {}).get("生成时间"),
            "requirements": requirements,
            "statistics": stats,
            "raw_output": raw_output,
        }

    def _normalize_requirements(self, extracted: List[Any]) -> List[Dict[str, Any]]:
        """Convert extractor requirement objects into plain, ordered dicts.

        Each entry gets a 1-based ``sort_order``, a title derived from the
        first 40 chars of the description (or a placeholder), and a priority
        looked up from ``PRIORITY_BY_TYPE``.
        """
        normalized: List[Dict[str, Any]] = []
        for index, req in enumerate(extracted, start=1):
            description = (req.description or "").strip()
            title = description[:40] if description else f"需求项 {index}"
            source_field = f"{req.section_number} {req.section_title}".strip() or "文档解析"
            normalized.append(
                {
                    "id": req.id,
                    "title": title,
                    "description": description,
                    "priority": self.PRIORITY_BY_TYPE.get(req.type, ""),
                    "acceptance_criteria": [description] if description else ["待补充验收标准"],
                    "source_field": source_field,
                    "section_number": req.section_number,
                    "section_title": req.section_title,
                    "requirement_type": req.type,
                    "sort_order": index,
                }
            )
        return normalized

    def _load_config(self) -> Dict[str, Any]:
        """Load the tool config from ``default_config.yaml`` and overlay LLM settings.

        Missing config file yields an empty dict.  The ``llm`` section is
        always populated from application settings, with the DashScope API key
        also sourced from the ``DASHSCOPE_API_KEY`` environment variable.
        """
        config_path = Path(__file__).with_name("default_config.yaml")
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as handle:
                config = yaml.safe_load(handle) or {}
        else:
            config = {}
        config.setdefault("llm", {})
        config["llm"]["model"] = settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
        config["llm"]["api_key"] = settings.DASH_SCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
        config["llm"]["api_base"] = settings.DASH_SCOPE_API_BASE
        config["llm"]["enabled"] = bool(config["llm"].get("api_key"))
        return config

    def _build_llm(self, config: Dict[str, Any], enable_llm: bool) -> QwenLLM | None:
        """Build the Qwen LLM client from *config*.

        Raises:
            ValueError: If ``enable_llm`` is False or no API key is present.
        """
        if not enable_llm:
            raise ValueError("当前版本仅支持LLM模式")
        llm_cfg = config.get("llm", {})
        api_key = llm_cfg.get("api_key")
        if not api_key:
            raise ValueError("未配置API密钥请设置 DASH_SCOPE_API_KEY 或 DASHSCOPE_API_KEY")
        return QwenLLM(
            api_key=api_key,
            model=llm_cfg.get("model", "qwen3-max"),
            api_endpoint=llm_cfg.get("api_base") or settings.DASH_SCOPE_API_BASE,
            temperature=llm_cfg.get("temperature", 0.3),
            max_tokens=llm_cfg.get("max_tokens", 1024),
        )
# Process-wide SRSTool instance; created lazily by get_srs_tool().
_SRS_TOOL_SINGLETON: SRSTool | None = None


def get_srs_tool() -> SRSTool:
    """Return the shared SRSTool instance, building it on first use."""
    global _SRS_TOOL_SINGLETON
    tool = _SRS_TOOL_SINGLETON
    if tool is None:
        tool = SRSTool()
        _SRS_TOOL_SINGLETON = tool
    return tool