init. project

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from app.tools.base import ToolDefinition
from app.tools.registry import ToolRegistry
__all__ = ["ToolDefinition", "ToolRegistry"]

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict
@dataclass(frozen=True)
class ToolDefinition:
    """Immutable metadata record describing a registered tool."""

    name: str  # unique key used by the registry
    version: str
    description: str
    input_schema: Dict[str, Any]   # schema dict for tool inputs (shape not enforced here)
    output_schema: Dict[str, Any]  # schema dict for tool outputs (shape not enforced here)

View File

@@ -0,0 +1,19 @@
from typing import Dict, List
from app.tools.base import ToolDefinition
class ToolRegistry:
    """Process-wide registry of ToolDefinition objects, keyed by tool name."""

    # Class-level store: shared by design across the whole process.
    _tools: Dict[str, ToolDefinition] = {}

    @classmethod
    def register(cls, definition: ToolDefinition) -> None:
        """Register *definition*, replacing any previously registered tool of the same name."""
        cls._tools.update({definition.name: definition})

    @classmethod
    def get(cls, name: str) -> ToolDefinition:
        """Return the definition registered under *name* (raises KeyError when absent)."""
        return cls._tools[name]

    @classmethod
    def list(cls) -> List[ToolDefinition]:
        """Return the registered definitions as a new list."""
        return [definition for definition in cls._tools.values()]

View File

@@ -0,0 +1,3 @@
from app.tools.srs_reqs_qwen.tool import SRSTool, get_srs_tool
__all__ = ["SRSTool", "get_srs_tool"]

View File

@@ -0,0 +1,102 @@
# 配置文件 - SRS 需求文档解析工具 (LLM增强版)
# Configuration file for SRS Requirement Document Parser (LLM Enhanced Version)
# LLM配置 - 阿里云千问
llm:
# 是否启用LLM设为false则使用纯规则提取
enabled: true
# LLM提供商qwen阿里云千问
provider: "qwen"
# 模型名称
model: "qwen3-max"
# API密钥统一由 rag-web-ui 的环境变量提供
api_key: ""
# 可选参数
temperature: 0.3
max_tokens: 1024
# 文档解析配置
document:
supported_formats:
- ".pdf"
- ".docx"
# 标题识别的样式列表
heading_styles:
- "Heading 1"
- "Heading 2"
- "Heading 3"
- "Heading 4"
- "Heading 5"
# 需要过滤的非需求章节GJB438B标准
non_requirement_sections:
- "标识"
- "系统概述"
- "文档概述"
- "引用文档"
- "合格性规定"
- "需求可追踪性"
- "注释"
- "附录"
# 需求提取配置
extraction:
# 需求类型关键字(用于自动判断需求类型)
requirement_types:
功能需求:
prefix: "FR"
keywords: ["功能", "feature", "requirement", "CSCI组成", "控制", "处理", "监测", "显示"]
priority: 1
接口需求:
prefix: "IR"
keywords: ["接口", "interface", "api", "外部接口", "内部接口", "CAN", "以太网", "通信"]
priority: 2
性能需求:
prefix: "PR"
keywords: ["性能", "performance", "速度", "响应时间", "吞吐量"]
priority: 3
安全需求:
prefix: "SR"
keywords: ["安全", "security", "安全性", "报警"]
priority: 4
可靠性需求:
prefix: "RR"
keywords: ["可靠", "reliability", "容错", "恢复", "冗余"]
priority: 5
其他需求:
prefix: "OR"
keywords: ["约束", "资源", "适应性", "保密", "环境", "计算机", "质量", "设计", "人员", "培训", "保障", "验收", "交付"]
priority: 6
splitter:
enabled: true
max_sentence_len: 120
min_clause_len: 12
semantic_guard:
enabled: true
preserve_condition_action_chain: true
preserve_alarm_chain: true
table_strategy:
llm_semantic_enabled: true
sequence_table_merge: "single_requirement"
merge_time_series_rows_min: 3
rewrite_policy:
llm_light_rewrite_enabled: true
preserve_ratio_min: 0.65
max_length_growth_ratio: 1.25
renumber_policy:
enabled: true
mode: "section_continuous"
# 输出配置
output:
format: "json"
indent: 2
# 是否美化输出(格式化)
pretty_print: true
# 是否包含元数据
include_metadata: true
# 日志配置
logging:
level: "INFO" # DEBUG, INFO, WARNING, ERROR
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
file: "srs_parser.log"

View File

@@ -0,0 +1,26 @@
# src/__init__.py
"""SRS requirement-document parsing toolkit.

Re-exports the public parser, extractor, splitter and generator classes.
"""
__version__ = "1.0.0"
__author__ = "SRS Parser Team"
from .document_parser import DocumentParser
from .llm_interface import LLMInterface, QwenLLM
from .requirement_extractor import RequirementExtractor
from .json_generator import JSONGenerator
from .settings import AppSettings
from .requirement_splitter import RequirementSplitter
from .requirement_id_generator import RequirementIDGenerator
__all__ = [
    'DocumentParser',
    'LLMInterface',
    'QwenLLM',
    'RequirementExtractor',
    'JSONGenerator',
    'AppSettings',
    'RequirementSplitter',
    'RequirementIDGenerator',
]

View File

@@ -0,0 +1,709 @@
# -*- coding: utf-8 -*-
"""
文档解析模块 - LLM增强版
支持PDF和Docx格式针对GJB438B标准SRS文档优化
"""
import os
import re
import logging
import importlib
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Optional, Any
from pathlib import Path
# Optional third-party parsing backends: the module degrades gracefully when
# a backend is missing instead of failing at import time.
try:
    from docx import Document
    HAS_DOCX = True
except ImportError:
    HAS_DOCX = False
try:
    import PyPDF2
    HAS_PDF = True
except ImportError:
    HAS_PDF = False
# BUG FIX: a bare `import importlib` does not guarantee that the
# `importlib.util` submodule is bound as an attribute; import it explicitly
# before calling find_spec(), otherwise this line can raise AttributeError.
import importlib.util
HAS_PDF_TABLE = importlib.util.find_spec("pdfplumber") is not None
logger = logging.getLogger(__name__)
class Section:
    """A single section (heading plus its content) of a parsed document.

    A section carries heading metadata (level/number/title), accumulated
    plain-text content, extracted tables, child sections, and an ordered
    ``blocks`` list that preserves the original interleaving of text and
    tables.
    """

    def __init__(self, level: int, title: str, number: Optional[str] = None,
                 content: str = "", uid: str = ""):
        # FIX: `number` was annotated `str` while defaulting to None; it is
        # genuinely optional, so the annotation is now Optional[str].
        self.level = level
        self.title = title
        self.number = number
        self.content = content
        self.uid = uid
        self.parent: Optional['Section'] = None
        self.children: List['Section'] = []
        self.tables: List[List[List[str]]] = []
        # Ordered mix of {"type": "text", ...} and {"type": "table", ...} entries.
        self.blocks: List[Dict[str, Any]] = []

    def add_child(self, child: 'Section') -> None:
        """Append *child* and set its back-reference to this section."""
        self.children.append(child)
        child.parent = self

    def add_content(self, text: str) -> None:
        """Append a stripped text block; empty/whitespace-only text is ignored."""
        text = (text or "").strip()
        if not text:
            return
        if self.content:
            self.content += "\n" + text
        else:
            self.content = text
        self.blocks.append({"type": "text", "text": text})

    def add_table(self, table_data: List[List[str]]) -> None:
        """Append a table (list of rows of cell strings); empty tables are ignored."""
        if not table_data:
            return
        self.tables.append(table_data)
        table_index = len(self.tables) - 1
        self.blocks.append({"type": "table", "table_index": table_index, "table": table_data})

    def generate_auto_number(self, parent_number: str = "", sibling_index: int = 1) -> None:
        """Assign an automatic section number when none was parsed.

        Args:
            parent_number: number of the parent section ("" for top level).
            sibling_index: 1-based position among sibling sections.
        """
        if not self.number:
            if parent_number:
                self.number = f"{parent_number}.{sibling_index}"
            else:
                self.number = str(sibling_index)

    def __repr__(self) -> str:
        return f"Section(level={self.level}, number='{self.number}', title='{self.title}')"
class DocumentParser(ABC):
    """Abstract base class for document parsers producing a Section tree."""

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.sections: List[Section] = []
        self.document_title = ""
        self.raw_text = ""
        # Optional LLM handle, used by subclasses for section validation.
        self.llm = None
        self._uid_counter = 0

    def set_llm(self, llm) -> None:
        """Attach an LLM instance (used by subclasses that support it)."""
        self.llm = llm

    @abstractmethod
    def parse(self) -> List[Section]:
        """Parse the document and return the list of top-level sections."""
        pass

    def get_document_title(self) -> str:
        """Return the document title detected during parsing."""
        return self.document_title

    def _next_uid(self) -> str:
        """Return the next per-parser unique section id ("sec-1", "sec-2", ...)."""
        self._uid_counter += 1
        return f"sec-{self._uid_counter}"

    def _auto_number_sections(self, sections: List[Section], parent_number: str = "") -> None:
        """Assign numbers to sections that were parsed without one.

        At the top level, front-matter chapters (everything before the first
        level-1 chapter whose title matches a body keyword) are left
        unnumbered, and numbering starts at 1 from that first body chapter.
        Below the top level, children are numbered "<parent>.<index>".

        Args:
            sections: sections at the current level.
            parent_number: number of the enclosing section ("" at top level).
        """
        # Front-matter renumbering only applies at the top level.
        if not parent_number:
            # Keywords of front-matter chapters to skip.
            # NOTE(review): this list is never referenced below, and its
            # empty-string entries look like characters lost in extraction —
            # confirm against the original source.
            skip_keywords = ['目录', '封面', '扉页', '未命名', '', '']
            # Keywords that mark the start of body chapters.
            content_keywords = ['外部接口', '接口', '软件需求', '需求', '功能', '性能', '设计', '概述', '标识', '引言']
            start_index = 0
            for idx, section in enumerate(sections):
                # The first level-1 chapter matching a body keyword starts the body.
                is_content = any(kw in section.title for kw in content_keywords)
                if is_content and section.level == 1:
                    start_index = idx
                    break
            # Renumber all top-level chapters.
            counter = 1
            for i, section in enumerate(sections):
                if i < start_index:
                    # Front matter gets no number.
                    section.number = ""
                else:
                    # Body: level-1 chapters are numbered from 1 upward.
                    if section.level == 1:
                        section.number = str(counter)
                        counter += 1
                # Recurse into children.
                if section.children:
                    self._auto_number_sections(section.children, section.number)
        else:
            # Child-level numbering: "<parent>.<index>" (1-based).
            for i, section in enumerate(sections, 1):
                if not section.number or self._is_chinese_number(section.number):
                    section.generate_auto_number(parent_number, i)
                if section.children:
                    self._auto_number_sections(section.children, section.number)

    def _is_chinese_number(self, text: str) -> bool:
        """True when *text* is non-empty and consists solely of Chinese numerals."""
        chinese_numbers = '一二三四五六七八九十百千万'
        return text and all(c in chinese_numbers for c in text)
class DocxParser(DocumentParser):
    """Parser for DOCX documents (requires python-docx)."""

    def __init__(self, file_path: str):
        if not HAS_DOCX:
            raise ImportError("python-docx库未安装请运行: pip install python-docx")
        super().__init__(file_path)
        self.document = None

    def parse(self) -> List[Section]:
        """Walk paragraphs and tables in document order and build the section tree."""
        try:
            self.document = Document(self.file_path)
            self.document_title = self.document.core_properties.title or "SRS Document"
            # Maps heading level -> most recently opened section at that level.
            section_stack = {}
            for block in self._iter_block_items(self.document):
                from docx.text.paragraph import Paragraph
                from docx.table import Table
                if isinstance(block, Paragraph):
                    text = block.text.strip()
                    if not text:
                        continue
                    heading_info = self._parse_heading(block, text)
                    if heading_info:
                        number, title, level = heading_info
                        section = Section(level=level, title=title, number=number, uid=self._next_uid())
                        if level == 1 or not section_stack:
                            self.sections.append(section)
                            section_stack = {1: section}
                        else:
                            # Attach to the nearest open ancestor level.
                            parent_level = level - 1
                            while parent_level >= 1 and parent_level not in section_stack:
                                parent_level -= 1
                            if parent_level >= 1 and parent_level in section_stack:
                                section_stack[parent_level].add_child(section)
                            elif self.sections:
                                self.sections[-1].add_child(section)
                        section_stack[level] = section
                        # Close any deeper levels left open by a previous branch.
                        for l in list(section_stack.keys()):
                            if l > level:
                                del section_stack[l]
                    else:
                        # Plain paragraph: append to the deepest open section.
                        if section_stack:
                            max_level = max(section_stack.keys())
                            section_stack[max_level].add_content(text)
                        else:
                            # Content before any heading: open a default section.
                            default_section = Section(level=1, title="未命名章节", number="", uid=self._next_uid())
                            default_section.add_content(text)
                            self.sections.append(default_section)
                            section_stack = {1: default_section}
                elif isinstance(block, Table):
                    # Tables also attach to the deepest open section.
                    table_data = self._extract_table_data(block)
                    if table_data:
                        if section_stack:
                            max_level = max(section_stack.keys())
                            section_stack[max_level].add_table(table_data)
                        else:
                            default_section = Section(level=1, title="未命名章节", number="", uid=self._next_uid())
                            default_section.add_table(table_data)
                            self.sections.append(default_section)
                            section_stack = {1: default_section}
            # Fill in numbers for sections parsed without one.
            self._auto_number_sections(self.sections)
            logger.info(f"完成Docx解析提取{len(self.sections)}个顶级章节")
            return self.sections
        except Exception as e:
            logger.error(f"解析Docx文档失败: {e}")
            raise

    def _is_valid_heading(self, text: str) -> bool:
        """Heuristically reject strings that cannot be headings."""
        if len(text) > 120 or '...' in text:
            return False
        # A heading must contain CJK or Latin letters.
        if not re.search(r'[\u4e00-\u9fa5A-Za-z]', text):
            return False
        # Filter table-of-contents entries: a title followed by a page number,
        # e.g. "概述 2" or "概述 . . . . 2".
        if re.search(r'\s{2,}\d+$', text):  # several spaces then trailing digits
            return False
        if re.search(r'[\.。\s]+\d+$', text):  # dots/spaces then trailing digits
            return False
        return True

    def _parse_heading(self, paragraph, text: str) -> Optional[Tuple[str, str, int]]:
        """Return (number, title, level) when *text* is a heading, else None."""
        style_name = paragraph.style.name if paragraph.style else ""
        is_heading_style = style_name.lower().startswith('heading') if style_name else False
        # Numeric headings such as "3.1 功能需求" (level = dot depth).
        match = re.match(r'^(\d+(?:\.\d+)*)\s*[\.、]?\s*(.+)$', text)
        if match and self._is_valid_heading(match.group(2)):
            number = match.group(1)
            title = match.group(2).strip()
            level = len(number.split('.'))
            return number, title, level
        # Chinese-numeral headings such as "一、概述" (always level 1).
        match = re.match(r'^([一二三四五六七八九十]+)[、\.]+\s*(.+)$', text)
        if match and self._is_valid_heading(match.group(2)):
            number = match.group(1)
            title = match.group(2).strip()
            level = 1
            return number, title, level
        # Style-based headings: level derived from the style name ("Heading 2" -> 2).
        if is_heading_style and self._is_valid_heading(text):
            level = 1
            level_match = re.search(r'(\d+)', style_name)
            if level_match:
                level = int(level_match.group(1))
            return "", text, level
        return None

    def _iter_block_items(self, parent):
        """Yield Paragraph and Table objects in true document order (via oxml)."""
        from docx.text.paragraph import Paragraph
        from docx.table import Table
        from docx.oxml.text.paragraph import CT_P
        from docx.oxml.table import CT_Tbl
        for child in parent.element.body.iterchildren():
            if isinstance(child, CT_P):
                yield Paragraph(child, parent)
            elif isinstance(child, CT_Tbl):
                yield Table(child, parent)

    def _extract_table_data(self, table) -> List[List[str]]:
        """Flatten a docx table into rows of normalized cell text, dropping empty rows."""
        table_data = []
        for row in table.rows:
            row_data = []
            for cell in row.cells:
                text = cell.text.replace('\n', ' ').strip()
                text = re.sub(r'\s+', ' ', text)
                row_data.append(text)
            if any(cell for cell in row_data):
                table_data.append(row_data)
        return table_data
class PDFParser(DocumentParser):
    """Parser for PDF documents — LLM-enhanced version.

    Extracts raw text with PyPDF2, rebuilds the section tree heuristically,
    optionally validates sections with an attached LLM, and attaches tables
    extracted via pdfplumber (when that optional dependency is installed).
    """

    # Keywords that valid GJB438B-style SRS section titles tend to contain.
    # NOTE(review): this list is not referenced anywhere in this class as
    # shown here — confirm whether it is used elsewhere or is dead data.
    VALID_TITLE_KEYWORDS = [
        '范围', '标识', '概述', '引用', '文档',
        '需求', '功能', '接口', '性能', '安全', '保密',
        '环境', '资源', '质量', '设计', '约束',
        '人员', '培训', '保障', '验收', '交付', '包装',
        '优先', '关键', '合格', '追踪', '注释',
        'CSCI', '计算机', '软件', '硬件', '通信', '通讯',
        '数据', '适应', '可靠', '内部', '外部',
        '描述', '要求', '规定', '说明', '定义',
        '电场', '防护', '装置', '控制', '监控', '显控'
    ]
    # Substrings that mark obviously invalid (noise) section titles.
    INVALID_TITLE_PATTERNS = [
        '本文档可作为', '参比电位', '补偿电流', '以太网',
        '电源', '软件接', '功能\\', '性能 \\', '输入/输出 \\',
        '数据处理要求 \\', '固件 \\', '质量控制要求',
        '信安科技', '浙江', '公司'
    ]

    def __init__(self, file_path: str):
        if not HAS_PDF:
            raise ImportError("PyPDF2库未安装请运行: pip install PyPDF2")
        super().__init__(file_path)
        self.document_title = "SRS Document"
        # Per-page extracted text, kept so tables can be matched back to pages.
        self._page_texts: List[str] = []

    def parse(self) -> List[Section]:
        """Parse the PDF and return the top-level section tree."""
        try:
            # 1. Extract all text.
            self.raw_text = self._extract_all_text()
            # 2. Clean the text (page numbers, TOC leader lines).
            cleaned_text = self._clean_text(self.raw_text)
            # 3. Detect the section structure.
            self.sections = self._parse_sections(cleaned_text)
            # 4. Validate/clean the sections with the LLM when available.
            if self.llm:
                self.sections = self._llm_validate_sections(self.sections)
            # Fallback section when detection produced nothing, so later
            # table attachment does not silently lose data.
            if not self.sections:
                fallback = Section(level=1, title="未命名章节", number="1", uid=self._next_uid())
                if cleaned_text:
                    fallback.add_content(cleaned_text)
                self.sections = [fallback]
            # 5. Extract PDF tables and attach them to sections (optional dep).
            pdf_tables = self._extract_pdf_tables()
            if pdf_tables:
                self._attach_pdf_tables_to_sections(pdf_tables)
            # 6. Auto-number sections that have no number.
            self._auto_number_sections(self.sections)
            logger.info(f"完成PDF解析提取{len(self.sections)}个顶级章节")
            return self.sections
        except Exception as e:
            logger.error(f"解析PDF文档失败: {e}")
            raise

    def _extract_all_text(self) -> str:
        """Extract text from every page; also fills ``self._page_texts``.

        NOTE(review): pages yielding no text are skipped, so indices in
        ``_page_texts`` can drift from pdfplumber's page indices used in
        ``_extract_pdf_tables`` — confirm whether that is acceptable.
        """
        all_text = []
        with open(self.file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            for page in pdf_reader.pages:
                text = page.extract_text()
                if text:
                    all_text.append(text)
        self._page_texts = all_text
        return '\n'.join(all_text)

    def _extract_pdf_tables(self) -> List[Dict[str, Any]]:
        """Extract tables via pdfplumber; returns [] when unavailable or on error."""
        if not HAS_PDF_TABLE:
            logger.warning("未安装pdfplumber跳过PDF表格提取。可执行: pip install pdfplumber")
            return []
        tables: List[Dict[str, Any]] = []
        try:
            pdfplumber = importlib.import_module("pdfplumber")
            with pdfplumber.open(self.file_path) as pdf:
                for page_idx, page in enumerate(pdf.pages):
                    page_text = ""
                    if page_idx < len(self._page_texts):
                        page_text = self._page_texts[page_idx]
                    extracted_tables = page.extract_tables() or []
                    for table_idx, table in enumerate(extracted_tables):
                        # Normalize cells and drop fully empty rows.
                        cleaned_table: List[List[str]] = []
                        for row in table or []:
                            cells = [re.sub(r'\s+', ' ', str(cell or '')).strip() for cell in row]
                            if any(cells):
                                cleaned_table.append(cells)
                        if cleaned_table:
                            tables.append(
                                {
                                    "page_idx": page_idx,
                                    "table_idx": table_idx,
                                    "page_text": page_text,
                                    "data": cleaned_table,
                                }
                            )
        except Exception as e:
            # Best-effort: table extraction failure falls back to text-only flow.
            logger.warning(f"PDF表格提取失败继续纯文本流程: {e}")
            return []
        logger.info(f"PDF表格提取完成{len(tables)}个表格")
        return tables

    def _attach_pdf_tables_to_sections(self, tables: List[Dict[str, Any]]) -> None:
        """Attach each extracted table to the best-matching section."""
        flat_sections = self._flatten_sections(self.sections)
        if not flat_sections:
            return
        last_section: Optional[Section] = None
        for table in tables:
            matched = self._match_table_section(table.get("page_text", ""), flat_sections)
            # Fall back to the previous table's section, then the first section.
            target = matched or last_section or flat_sections[0]
            target.add_table(table["data"])
            last_section = target

    def _flatten_sections(self, sections: List[Section]) -> List[Section]:
        """Flatten the section tree in document order."""
        result: List[Section] = []
        for section in sections:
            result.append(section)
            if section.children:
                result.extend(self._flatten_sections(section.children))
        return result

    def _match_table_section(self, page_text: str, sections: List[Section]) -> Optional[Section]:
        """Match a table's page text to the section whose heading it contains.

        Candidates per section are the title, "<number><title>" and
        "<number> <title>", compared whitespace-insensitively and
        case-insensitively; the longest matching candidate wins.
        """
        normalized_page = re.sub(r"\s+", "", (page_text or "")).lower()
        if not normalized_page:
            return None
        matched: Optional[Section] = None
        matched_score = -1
        for section in sections:
            title = (section.title or "").strip()
            if not title:
                continue
            number = (section.number or "").strip()
            candidates = [title]
            if number:
                candidates.append(f"{number}{title}")
                candidates.append(f"{number} {title}")
            for candidate in candidates:
                normalized_candidate = re.sub(r"\s+", "", candidate).lower()
                if normalized_candidate and normalized_candidate in normalized_page:
                    score = len(normalized_candidate)
                    if score > matched_score:
                        matched = section
                        matched_score = score
        return matched

    def _clean_text(self, text: str) -> str:
        """Drop page numbers and table-of-contents lines from extracted text."""
        lines = text.split('\n')
        cleaned_lines = []
        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Skip page numbers (1-3 digits alone on a line).
            if re.match(r'^\d{1,3}$', line):
                continue
            # Skip TOC leader lines (".... 12").
            if line.count('.') > 10 and '...' in line:
                continue
            cleaned_lines.append(line)
        return '\n'.join(cleaned_lines)

    def _parse_sections(self, text: str) -> List[Section]:
        """Rebuild the section tree from cleaned text, line by line."""
        sections = []
        # Maps heading level -> most recently opened section at that level.
        section_stack = {}
        lines = text.split('\n')
        current_section = None
        content_buffer = []
        found_sections = set()
        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Try to recognize a section header.
            section_info = self._match_section_header(line, found_sections)
            if section_info:
                number, title = section_info
                level = len(number.split('.'))
                # Flush buffered content into the previous section.
                if current_section and content_buffer:
                    current_section.add_content('\n'.join(content_buffer))
                    content_buffer = []
                # Create the new section.
                section = Section(level=level, title=title, number=number, uid=self._next_uid())
                found_sections.add(number)
                # Link it into the hierarchy.
                if level == 1:
                    sections.append(section)
                    section_stack = {1: section}
                else:
                    parent_level = level - 1
                    while parent_level >= 1 and parent_level not in section_stack:
                        parent_level -= 1
                    if parent_level >= 1 and parent_level in section_stack:
                        section_stack[parent_level].add_child(section)
                    elif sections:
                        sections[-1].add_child(section)
                    else:
                        sections.append(section)
                        section_stack = {1: section}
                section_stack[level] = section
                # Close any deeper levels left open by a previous branch.
                for l in list(section_stack.keys()):
                    if l > level:
                        del section_stack[l]
                current_section = section
            else:
                # Accumulate content lines that are not noise.
                if line and not self._is_noise(line):
                    content_buffer.append(line)
        # Flush the final section's content.
        if current_section and content_buffer:
            current_section.add_content('\n'.join(content_buffer))
        return sections

    def _match_section_header(self, line: str, found_sections: set) -> Optional[Tuple[str, str]]:
        """Return (number, title) when *line* looks like a section header, else None."""
        # Pattern: "3.1功能需求" or "3.1 功能需求".
        match = re.match(r'^(\d+(?:\.\d+)*)\s*(.+)$', line)
        if not match:
            return None
        number = match.group(1)
        title = match.group(2).strip()
        # Reject table-of-contents lines.
        if '...' in title or title.count('.') > 5:
            return None
        # Validate the section number.
        parts = number.split('.')
        first_part = int(parts[0])
        # Relaxed range for top-level numbers (non-strict GJB structure).
        if first_part < 1 or first_part > 30:
            return None
        # Sub-part numbers must stay plausible.
        for part in parts[1:]:
            if int(part) > 20:
                return None
        # Skip numbers already seen.
        if number in found_sections:
            return None
        # Title length sanity check.
        if len(title) > 60 or len(title) < 2:
            return None
        # Relaxed charset check (some PDF fonts garble Chinese extraction).
        if not re.search(r'[\u4e00-\u9fa5A-Za-z]', title):
            return None
        # Reject known-invalid title substrings.
        for invalid_pattern in self.INVALID_TITLE_PATTERNS:
            if invalid_pattern in title:
                return None
        # Titles must not start with a digit.
        if title[0].isdigit():
            return None
        # Reject digit-heavy titles.
        digit_ratio = sum(c.isdigit() for c in title) / max(len(title), 1)
        if digit_ratio > 0.3:
            return None
        # Backslashes usually indicate table noise.
        if '\\' in title and '需求' not in title:
            return None
        return (number, title)

    def _is_noise(self, line: str) -> bool:
        """True for lines that are layout noise rather than content."""
        # Purely numeric/punctuation lines.
        if re.match(r'^[\d\s,.]+$', line):
            return True
        # Very short lines.
        if len(line) < 3:
            return True
        # Roman numerals (front-matter page numbers).
        if re.match(r'^[ivxIVX]+$', line):
            return True
        return False

    def _llm_validate_sections(self, sections: List[Section]) -> List[Section]:
        """Keep only sections that pass validation (recursing into children)."""
        if not self.llm:
            return sections
        validated_sections = []
        for section in sections:
            # Validate each top-level section.
            if self._is_valid_section_with_llm(section):
                # Then recursively validate its children.
                section.children = self._validate_children(section.children)
                validated_sections.append(section)
        return validated_sections

    def _validate_children(self, children: List[Section]) -> List[Section]:
        """Recursively validate child sections."""
        validated = []
        for child in children:
            if self._is_valid_section_with_llm(child):
                child.children = self._validate_children(child.children)
                validated.append(child)
        return validated

    def _is_valid_section_with_llm(self, section: Section) -> bool:
        """Rule-based section validity check.

        NOTE(review): despite the name, this method never actually calls
        ``self.llm`` — it only applies the rule filters below. Confirm
        whether an LLM call was intended here.
        """
        # Fast rule filter for obviously invalid sections.
        invalid_titles = [
            '本文档可作为', '故障', '实时', '输入/输出',
            '固件', '功能\\', '\\4.', '\\3.'
        ]
        for invalid in invalid_titles:
            if invalid in section.title:
                logger.debug(f"过滤无效章节: {section.number} {section.title}")
                return False
        # Extra checks for requirement chapters (chapter 3).
        if section.number and section.number.startswith('3'):
            # A valid requirement-section title should be a clean phrase;
            # slashes/backslashes are only tolerated for I/O or interface titles.
            if '\\' in section.title or '/' in section.title:
                if not any(kw in section.title for kw in ['输入', '输出', '接口']):
                    return False
        return True
def create_parser(file_path: str) -> DocumentParser:
    """Factory: return the parser implementation matching the file extension.

    Raises:
        ValueError: when the extension is neither ``.docx`` nor ``.pdf``.
    """
    suffix = Path(file_path).suffix.lower()
    if suffix == '.docx':
        return DocxParser(file_path)
    if suffix == '.pdf':
        return PDFParser(file_path)
    raise ValueError(f"不支持的文件格式: {suffix}")

View File

@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
"""
JSON生成器模块 - LLM增强版
将提取的需求和章节结构转换为结构化JSON输出
"""
import json
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional
from .document_parser import Section
from .requirement_extractor import Requirement
from .settings import AppSettings
logger = logging.getLogger(__name__)
class JSONGenerator:
    """Generates the structured JSON output for sections and requirements."""

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or {}
        self.settings = AppSettings(self.config)

    def generate(self, sections: List[Section], requirements: List[Requirement],
                 document_title: str = "SRS Document") -> Dict[str, Any]:
        """Build the output dictionary.

        Args:
            sections: top-level section tree.
            requirements: extracted requirements.
            document_title: title placed in the metadata block.
        Returns:
            The structured output dict (Chinese keys, by design).
        """
        # Group requirements by owning section.
        reqs_by_section = self._group_requirements_by_section(requirements)
        # Count requirements per (Chinese) type label.
        type_stats = self._calculate_type_statistics(requirements)
        # Assemble the output structure.
        output = {
            "文档元数据": {
                "标题": document_title,
                "生成时间": datetime.now().isoformat(),
                "总需求数": len(requirements),
                "需求类型统计": type_stats
            },
            "需求内容": self._build_requirement_content(sections, reqs_by_section)
        }
        logger.info(f"生成JSON输出{len(requirements)}个需求")
        return output

    def _group_requirements_by_section(self, requirements: List[Requirement]) -> Dict[str, List[Requirement]]:
        """Group requirements by section uid (falling back to number, then 'unknown')."""
        grouped = {}
        for req in requirements:
            section_key = req.section_uid or req.section_number or 'unknown'
            if section_key not in grouped:
                grouped[section_key] = []
            grouped[section_key].append(req)
        return grouped

    def _calculate_type_statistics(self, requirements: List[Requirement]) -> Dict[str, int]:
        """Count requirements per Chinese type label ('其他需求' when unmapped)."""
        stats = {}
        for req in requirements:
            type_chinese = self.settings.type_chinese.get(req.type, '其他需求')
            if type_chinese not in stats:
                stats[type_chinese] = 0
            stats[type_chinese] += 1
        return stats

    def _should_include_section(self, section: Section) -> bool:
        """True unless settings mark the section title as a non-requirement section."""
        return not self.settings.is_non_requirement_section(section.title)

    def _build_requirement_content(self, sections: List[Section],
                                   reqs_by_section: Dict[str, List[Requirement]]) -> Dict[str, Any]:
        """Build the nested requirement-content mapping for the top-level sections."""
        content = {}
        for section in sections:
            # Skip non-requirement sections...
            if not self._should_include_section(section):
                # ...but still inspect their children, which may hold requirements.
                for child in section.children:
                    child_content = self._build_section_content_recursive(child, reqs_by_section)
                    if child_content:
                        key = f"{child.number} {child.title}" if child.number else child.title
                        content[key] = child_content
                continue
            section_content = self._build_section_content_recursive(section, reqs_by_section)
            if section_content:
                key = f"{section.number} {section.title}" if section.number else section.title
                content[key] = section_content
        return content

    def _build_section_content_recursive(self, section: Section,
                                         reqs_by_section: Dict[str, List[Requirement]]) -> Optional[Dict[str, Any]]:
        """Build one section's entry; None when it holds no requirements or children."""
        # Excluded sections contribute nothing.
        if not self._should_include_section(section):
            return None
        # Basic section info.
        result = {
            "章节信息": {
                "章节编号": section.number or "",
                "章节标题": section.title,
                "章节级别": section.level
            }
        }
        # Recurse into children first.
        has_valid_children = False
        subsections = {}
        for child in section.children:
            child_content = self._build_section_content_recursive(child, reqs_by_section)
            if child_content:
                has_valid_children = True
                key = f"{child.number} {child.title}" if child.number else child.title
                subsections[key] = child_content
        # Requirements belonging directly to this section, in source order.
        reqs = reqs_by_section.get(section.uid or section.number or 'unknown', [])
        reqs = sorted(reqs, key=lambda r: getattr(r, 'source_order', 0))
        if reqs:
            result["需求列表"] = []
            for req in reqs:
                # Requirement type first, for readability of the output.
                type_chinese = self.settings.type_chinese.get(req.type, '功能需求')
                req_dict = {
                    "需求类型": type_chinese,
                    "需求编号": req.id,
                    "需求描述": req.description
                }
                # Interface requirements carry extra fields.
                if req.type == 'interface':
                    req_dict["接口名称"] = req.interface_name
                    req_dict["接口类型"] = req.interface_type
                    req_dict["来源"] = req.source
                    req_dict["目的地"] = req.destination
                result["需求列表"].append(req_dict)
        # Attach child sections when present.
        if has_valid_children:
            result["子章节"] = subsections
        # Drop sections with neither requirements nor children.
        if "需求列表" not in result and "子章节" not in result:
            return None
        return result

    def save_to_file(self, output: Dict[str, Any], file_path: str) -> None:
        """Write *output* as UTF-8 JSON, honoring output.indent/pretty_print config.

        Args:
            output: output dictionary.
            file_path: destination path.
        Raises:
            OSError / TypeError: re-raised after logging when writing fails.
        """
        try:
            output_cfg = self.config.get("output", {})
            indent = output_cfg.get("indent", 2)
            pretty = output_cfg.get("pretty_print", True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(output, f, ensure_ascii=False, indent=indent if pretty else None)
            logger.info(f"成功保存JSON到: {file_path}")
        except Exception as e:
            logger.error(f"保存JSON文件失败: {e}")
            raise

    def generate_and_save(self, sections: List[Section], requirements: List[Requirement],
                          document_title: str, file_path: str) -> Dict[str, Any]:
        """Generate the output dict and persist it to *file_path*.

        Returns the generated output dict.
        """
        output = self.generate(sections, requirements, document_title)
        self.save_to_file(output, file_path)
        return output

View File

@@ -0,0 +1,197 @@
# src/llm_interface.py
"""
LLM接口模块 - 支持多个LLM提供商
"""
import logging
import json
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from .utils import get_env_or_config
logger = logging.getLogger(__name__)
class LLMInterface(ABC):
    """Abstract base class for LLM providers."""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None, **kwargs):
        """Store common LLM settings.

        Args:
            api_key: provider API key.
            model: model name.
            **kwargs: extra provider parameters (e.g. temperature, max_tokens).
        """
        self.api_key = api_key
        self.model = model
        self.extra_params = kwargs

    @abstractmethod
    def call(self, prompt: str) -> str:
        """Send *prompt* to the LLM and return the raw response text."""
        pass

    @abstractmethod
    def call_json(self, prompt: str) -> Dict[str, Any]:
        """Send *prompt* and return the response parsed as a JSON dict."""
        pass

    def validate_config(self) -> bool:
        """True when both api_key and model are set."""
        return bool(self.api_key and self.model)
class QwenLLM(LLMInterface):
    """Aliyun Qwen (DashScope) implementation of LLMInterface."""

    def __init__(self, api_key: Optional[str] = None, model: str = "qwen-plus",
                 api_endpoint: Optional[str] = None, **kwargs):
        """Initialize the Qwen client.

        Args:
            api_key: Aliyun API key.
            model: model name (e.g. qwen-plus, qwen-turbo).
            api_endpoint: API endpoint URL (compatible-mode default when None).
            **kwargs: extra provider parameters.
        Raises:
            ImportError: when the dashscope SDK is not installed.
        """
        super().__init__(api_key, model, **kwargs)
        self.api_endpoint = api_endpoint or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self._check_dashscope_import()

    def _check_dashscope_import(self) -> None:
        """Fail fast when the dashscope SDK is missing."""
        try:
            import dashscope
            self.dashscope = dashscope
        except ImportError:
            logger.error("dashscope库未安装请运行: pip install dashscope")
            raise

    def call(self, prompt: str) -> str:
        """Call Qwen and return the response text.

        Handles both dict-style and object-style DashScope responses.

        Raises:
            ValueError: when api_key/model are not configured.
            Exception: when the API reports an error status.
        """
        if not self.validate_config():
            raise ValueError("LLM配置不完整api_key或model未设置")
        try:
            from dashscope import Generation
            # DashScope reads the key from a module-level attribute.
            self.dashscope.api_key = self.api_key
            # Request parameters in the dashscope 1.7.0 message format.
            response = Generation.call(
                model=self.model,
                messages=[
                    {'role': 'user', 'content': prompt}
                ],
                result_format='message'  # message-style output
            )
            # Debug logging of the raw response.
            logger.debug(f"API响应类型: {type(response)}")
            logger.debug(f"API响应内容: {response}")
            # Normalize the response shape.
            if isinstance(response, dict):
                # Dict-style response.
                status_code = response.get('status_code', 200)
                if status_code == 200:
                    output = response.get('output', {})
                    if 'choices' in output:
                        return output['choices'][0]['message']['content']
                    elif 'text' in output:
                        return output['text']
                    else:
                        # Last resort: best-effort text field.
                        return output.get('text', str(output))
                else:
                    error_msg = response.get('message', response.get('code', 'Unknown error'))
                    logger.error(f"千问API返回错误: {error_msg}")
                    raise Exception(f"API调用失败: {error_msg}")
            else:
                # Object-style response.
                if hasattr(response, 'status_code') and response.status_code == 200:
                    output = response.output
                    if hasattr(output, 'choices'):
                        return output.choices[0].message.content
                    elif hasattr(output, 'text'):
                        return output.text
                    else:
                        return str(output)
                elif hasattr(response, 'status_code'):
                    error_msg = getattr(response, 'message', str(response))
                    raise Exception(f"API调用失败: {error_msg}")
                else:
                    return str(response)
        except Exception as e:
            logger.error(f"调用千问LLM失败: {e}")
            raise

    def call_json(self, prompt: str) -> Dict[str, Any]:
        """Call Qwen and parse the response as JSON.

        Falls back to extracting a fenced json code block, then the first
        {...} object; returns an error dict when nothing parses.
        """
        # Ask for JSON explicitly in the prompt.
        json_prompt = prompt + "\n\n请确保响应是有效的JSON格式。"
        response = self.call(json_prompt)
        try:
            # First try parsing the whole response.
            return json.loads(response)
        except json.JSONDecodeError:
            # Then try a fenced json block, then any {...} object.
            try:
                import re
                json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
                if json_match:
                    return json.loads(json_match.group(1))
                json_match = re.search(r'\{.*\}', response, re.DOTALL)
                if json_match:
                    return json.loads(json_match.group(0))
            except Exception as e:
                logger.warning(f"无法从响应中提取JSON: {e}")
            # Give up: return a structured error payload.
            logger.error(f"无法解析LLM响应为JSON: {response}")
            return {"error": "Failed to parse response as JSON", "raw_response": response}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
"""
需求编号生成与提取工具。
"""
import re
from typing import Optional, Tuple, Dict
class RequirementIDGenerator:
    """Builds requirement IDs and extracts inline IDs from requirement text."""

    def __init__(self, type_prefix: Dict[str, str]):
        # Maps requirement type (e.g. "functional") -> ID prefix (e.g. "FR").
        self.type_prefix = type_prefix

    def normalize(self, req_id: str) -> str:
        """Return *req_id* stripped of surrounding whitespace ("" when falsy)."""
        return str(req_id).strip() if req_id else ""

    def extract_from_text(self, text: str) -> Tuple[Optional[str], str]:
        """Split a leading requirement ID off *text*.

        Recognizes forms like "FR-001: ...", "A1 ...", and "a) ...".
        Returns (id, remainder); id is None when nothing matches.
        """
        if not text:
            return None, text
        # Multi-letter IDs ("FR-001", "SRS_1.2") and letter+digits IDs ("A1").
        for pattern in (
            r"^\s*([A-Za-z]{2,10}[-_]\d+(?:[-.\d]+)*)\s*[:\)\]】]?\s*(.+)$",
            r"^\s*([A-Za-z]\d+)\s*[:\)\]】]?\s*(.+)$",
        ):
            found = re.match(pattern, text)
            if found:
                return found.group(1).strip(), found.group(2).strip()
        # Enumeration markers such as "a)" or "1)" — the ")" is dropped.
        found = re.match(r"^\s*([a-z0-9]{1,2}[\)])\s*(.+)$", text)
        if found:
            return found.group(1).strip().rstrip(")"), found.group(2).strip()
        return None, text

    def generate(
        self,
        req_type: str,
        section_number: str,
        index: int,
        doc_req_id: str = "",
        parent_req_id: str = "",
        split_index: int = 1,
        split_total: int = 1,
    ) -> str:
        """Return the final requirement ID, with an "-S<i>" suffix for split parts."""
        base_id = self._generate_base(req_type, section_number, index, doc_req_id, parent_req_id)
        return f"{base_id}-S{split_index}" if split_total > 1 else base_id

    def _generate_base(
        self,
        req_type: str,
        section_number: str,
        index: int,
        doc_req_id: str,
        parent_req_id: str,
    ) -> str:
        """Base ID: document ID when complete, else parent+doc, else prefix-section-index."""
        # A document-provided ID that already looks complete is kept ("_" -> "-").
        if doc_req_id and re.match(r"^[A-Za-z0-9]{2,10}[-_].+$", doc_req_id):
            return doc_req_id.replace("_", "-")
        # Partial document ID: qualify it with the parent requirement's ID.
        if doc_req_id and parent_req_id:
            return f"{parent_req_id}-{doc_req_id}"
        # Synthetic ID from type prefix, section number (or "NA") and index.
        section_part = section_number or "NA"
        return f"{self.type_prefix.get(req_type, 'FR')}-{section_part}-{index}"

View File

@@ -0,0 +1,188 @@
# -*- coding: utf-8 -*-
"""
需求长句拆分器。
将复合长句拆分为可验证的原子需求片段。
"""
import re
from typing import List
class RequirementSplitter:
ACTION_HINTS = [
"产生",
"发送",
"设置",
"进入",
"退出",
"关闭",
"开启",
"监测",
"判断",
"记录",
"上传",
"重启",
"恢复",
"关断",
"断电",
"加电",
"执行",
"进行",
]
CONNECTOR_HINTS = ["", "并且", "同时", "然后", "", "以及", ""]
CONDITIONAL_HINTS = ["如果", "", "", "", "其中", "此时", "满足"]
CONTEXT_PRONOUN_HINTS = ["", "", "上述", "", "这些", "那些"]
def __init__(self, max_sentence_len: int = 120, min_clause_len: int = 12):
    # max_sentence_len: fragments longer than this get further comma-level splitting.
    # min_clause_len: clauses shorter than this are never split off.
    self.max_sentence_len = max_sentence_len
    self.min_clause_len = min_clause_len
def split(self, text: str) -> List[str]:
    """Split a compound requirement sentence into atomic fragments.

    Returns [] for empty input; a sentence containing a complete
    condition-action chain is kept whole.
    """
    cleaned = self._clean(text)
    if not cleaned:
        return []
    if self._contains_strong_semantic_chain(cleaned):
        return [cleaned]
    # First split on strong punctuation into primary fragments.
    base_parts = self._split_by_strong_punctuation(cleaned)
    result: List[str] = []
    for part in base_parts:
        if len(part) <= self.max_sentence_len:
            result.append(part)
            continue
        # Over-long fragments are further split on commas/connectors.
        refined = self._split_long_clause(part)
        result.extend(refined)
    result = self._merge_semantic_chain(result)
    result = self._merge_too_short(result)
    return self._deduplicate(result)
def _contains_strong_semantic_chain(self, text: str) -> bool:
# 条件-动作链完整时,避免强拆。
has_conditional = any(h in text for h in ["如果", "", ""])
has_result = "" in text or "" in text
action_count = sum(1 for h in self.ACTION_HINTS if h in text)
if has_conditional and has_result and action_count >= 2:
return True
return False
def _clean(self, text: str) -> str:
text = re.sub(r"\s+", " ", text or "")
return text.strip(" ;;。")
def _split_by_strong_punctuation(self, text: str) -> List[str]:
chunks = re.split(r"[;。]", text)
return [c.strip(" ,") for c in chunks if c and c.strip(" ,")]
def _split_long_clause(self, clause: str) -> List[str]:
if self._contains_strong_semantic_chain(clause):
return [clause]
raw_parts = [x.strip() for x in re.split(r"[,]", clause) if x.strip()]
if len(raw_parts) <= 1:
return [clause]
assembled: List[str] = []
current = raw_parts[0]
for fragment in raw_parts[1:]:
if self._should_split(current, fragment):
assembled.append(current.strip())
current = fragment
else:
current = f"{current}{fragment}"
if current.strip():
assembled.append(current.strip())
return assembled
def _should_split(self, current: str, fragment: str) -> bool:
if len(current) < self.min_clause_len:
return False
# 指代承接片段通常是语义延续,不应切断。
if any(fragment.startswith(h) for h in self.CONTEXT_PRONOUN_HINTS):
return False
# 条件链中带“则/并/同时”的后继片段,优先保持在同一需求中。
if self._contains_strong_semantic_chain(current + "" + fragment):
return False
frag_starts_with_condition = any(fragment.startswith(h) for h in self.CONDITIONAL_HINTS)
if frag_starts_with_condition:
return False
has_connector = any(fragment.startswith(h) for h in self.CONNECTOR_HINTS)
has_action = any(h in fragment for h in self.ACTION_HINTS)
current_has_action = any(h in current for h in self.ACTION_HINTS)
# 连接词 + 动作词,且当前片段已经包含动作,优先拆分。
if has_connector and has_action and current_has_action:
return True
# 无连接词但出现新的动作片段且整体过长,也拆分。
if has_action and current_has_action and len(current) >= self.max_sentence_len // 2:
return True
return False
def _merge_semantic_chain(self, parts: List[str]) -> List[str]:
if not parts:
return []
merged: List[str] = [parts[0]]
for part in parts[1:]:
prev = merged[-1]
if self._should_merge(prev, part):
merged[-1] = f"{prev}{part}"
else:
merged.append(part)
return merged
def _should_merge(self, prev: str, current: str) -> bool:
# 指代开头:如“该报警信号...”。
if any(current.startswith(h) for h in self.CONTEXT_PRONOUN_HINTS):
return True
# 报警触发后的持续条件与动作属于同一链。
if ("报警" in prev and "持续" in current) or ("产生" in prev and "报警" in prev and "持续" in current):
return True
# 状态迁移 + 后续控制动作保持合并。
if ("进入" in prev or "设置" in prev or "发送" in prev) and ("" in current or "连续" in current):
return True
# 条件链分裂片段重新合并。
if self._contains_strong_semantic_chain(prev + "" + current):
return True
return False
def _merge_too_short(self, parts: List[str]) -> List[str]:
if not parts:
return []
merged: List[str] = []
for part in parts:
if merged and len(part) < self.min_clause_len:
merged[-1] = f"{merged[-1]}{part}"
else:
merged.append(part)
return merged
def _deduplicate(self, parts: List[str]) -> List[str]:
seen = set()
result = []
for part in parts:
key = re.sub(r"\s+", "", part)
if key and key not in seen:
seen.add(key)
result.append(part)
return result

View File

@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
"""
统一配置与映射模块。
将需求类型、章节过滤、输出映射和拆分参数收敛到单一入口。
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass
class RequirementTypeRule:
    """One requirement-classification rule read from configuration."""
    key: str  # internal type key, e.g. "functional" / "interface"
    chinese_name: str  # display name as written in the config file
    prefix: str  # requirement-ID prefix, e.g. "FR" / "IR"
    keywords: List[str]  # match keywords (compared case-insensitively at use site)
    priority: int  # evaluation order; lower value means tried first
class AppSettings:
"""从 config 读取并提供统一访问接口。"""
TYPE_NAME_MAP = {
"功能需求": "functional",
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
}
DEFAULT_NON_REQUIREMENT_SECTIONS = [
"标识",
"系统概述",
"文档概述",
"引用文档",
"合格性规定",
"需求可追踪性",
"注释",
"附录",
"范围",
"概述",
]
DEFAULT_TYPE_CHINESE = {
"functional": "功能需求",
"interface": "接口需求",
"performance": "其他需求",
"security": "其他需求",
"reliability": "其他需求",
"other": "其他需求",
}
DEFAULT_PREFIX = {
"functional": "FR",
"interface": "IR",
"performance": "PR",
"security": "SR",
"reliability": "RR",
"other": "OR",
}
def __init__(self, config: Dict[str, Any] = None):
self.config = config or {}
document_cfg = self.config.get("document", {})
self.non_requirement_sections = document_cfg.get(
"non_requirement_sections", self.DEFAULT_NON_REQUIREMENT_SECTIONS
)
extraction_cfg = self.config.get("extraction", {})
req_types_cfg = extraction_cfg.get("requirement_types", {})
self.requirement_rules = self._build_rules(req_types_cfg)
self.type_prefix = self._build_type_prefix(req_types_cfg)
self.type_chinese = self._build_type_chinese(req_types_cfg)
splitter_cfg = extraction_cfg.get("splitter", {})
self.splitter_max_sentence_len = int(splitter_cfg.get("max_sentence_len", 120))
self.splitter_min_clause_len = int(splitter_cfg.get("min_clause_len", 12))
self.splitter_enabled = bool(splitter_cfg.get("enabled", True))
semantic_cfg = extraction_cfg.get("semantic_guard", {})
self.semantic_guard_enabled = bool(semantic_cfg.get("enabled", True))
self.preserve_condition_action_chain = bool(
semantic_cfg.get("preserve_condition_action_chain", True)
)
self.preserve_alarm_chain = bool(semantic_cfg.get("preserve_alarm_chain", True))
table_cfg = extraction_cfg.get("table_strategy", {})
self.table_llm_semantic_enabled = bool(table_cfg.get("llm_semantic_enabled", True))
self.sequence_table_merge = table_cfg.get("sequence_table_merge", "single_requirement")
self.merge_time_series_rows_min = int(table_cfg.get("merge_time_series_rows_min", 3))
rewrite_cfg = extraction_cfg.get("rewrite_policy", {})
self.llm_light_rewrite_enabled = bool(rewrite_cfg.get("llm_light_rewrite_enabled", True))
self.preserve_ratio_min = float(rewrite_cfg.get("preserve_ratio_min", 0.65))
self.max_length_growth_ratio = float(rewrite_cfg.get("max_length_growth_ratio", 1.25))
renumber_cfg = extraction_cfg.get("renumber_policy", {})
self.renumber_enabled = bool(renumber_cfg.get("enabled", True))
self.renumber_mode = renumber_cfg.get("mode", "section_continuous")
def _build_rules(self, req_types_cfg: Dict[str, Dict[str, Any]]) -> List[RequirementTypeRule]:
rules: List[RequirementTypeRule] = []
if not req_types_cfg:
# 用默认两类保证兼容旧行为
return [
RequirementTypeRule(
key="interface",
chinese_name="接口需求",
prefix="IR",
keywords=["接口", "interface", "api", "串口", "通信", "CAN", "以太网"],
priority=1,
),
RequirementTypeRule(
key="functional",
chinese_name="功能需求",
prefix="FR",
keywords=["功能", "控制", "处理", "监测", "显示"],
priority=2,
),
]
for zh_name, item in req_types_cfg.items():
key = self.TYPE_NAME_MAP.get(zh_name, "other")
rules.append(
RequirementTypeRule(
key=key,
chinese_name=zh_name,
prefix=item.get("prefix", self.DEFAULT_PREFIX.get(key, "FR")),
keywords=item.get("keywords", []),
priority=int(item.get("priority", 99)),
)
)
return sorted(rules, key=lambda x: x.priority)
def _build_type_prefix(self, req_types_cfg: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
mapping = dict(self.DEFAULT_PREFIX)
for zh_name, key in self.TYPE_NAME_MAP.items():
if zh_name in req_types_cfg:
mapping[key] = req_types_cfg[zh_name].get("prefix", mapping[key])
return mapping
def _build_type_chinese(self, req_types_cfg: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
mapping = dict(self.DEFAULT_TYPE_CHINESE)
for zh_name, key in self.TYPE_NAME_MAP.items():
if zh_name in req_types_cfg:
mapping[key] = zh_name
return mapping
def is_non_requirement_section(self, title: str) -> bool:
return any(keyword in title for keyword in self.non_requirement_sections)
def detect_requirement_type(self, title: str, content: str) -> str:
combined_text = f"{title} {(content or '')[:500]}".lower()
for rule in self.requirement_rules:
for keyword in rule.keywords:
if keyword.lower() in combined_text:
return rule.key
return "functional"

View File

@@ -0,0 +1,134 @@
# src/utils.py
"""
工具函数模块 - 提供各种辅助功能
"""
import os
import logging
from pathlib import Path
from typing import Dict, Any, List, Optional
import yaml
logger = logging.getLogger(__name__)
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Load the YAML configuration file.

    Args:
        config_path: Path to the config file; when None, defaults to
            ``../config.yaml`` relative to this module.

    Returns:
        The parsed configuration dict. An empty dict is returned when the
        file is missing, empty, or unreadable — failures are logged, never
        raised, so callers degrade to built-in defaults.
    """
    if config_path is None:
        config_path = os.path.join(os.path.dirname(__file__), '..', 'config.yaml')
    if not os.path.exists(config_path):
        logger.warning(f"配置文件不存在: {config_path}")
        return {}
    try:
        # Keep the try body minimal: only the I/O and parse can raise here.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except Exception as e:
        # Best-effort: a malformed config degrades to defaults instead of crashing.
        logger.error(f"加载配置文件失败: {e}")
        return {}
    logger.info(f"成功加载配置文件: {config_path}")
    # safe_load returns None for an empty file — normalize to a dict.
    return config or {}
def setup_logging(config: Dict[str, Any]) -> None:
    """Configure the root logger from the ``logging`` section of *config*.

    Honors ``level`` (default INFO), ``format`` and an optional ``file``
    target; when no file is configured, a NullHandler fills the second
    handler slot so only the console handler emits.
    """
    opts = config.get('logging', {})
    log_file = opts.get('file', None)
    handlers = [logging.StreamHandler()]
    if log_file:
        handlers.append(logging.FileHandler(log_file))
    else:
        handlers.append(logging.NullHandler())
    logging.basicConfig(
        level=getattr(logging, opts.get('level', 'INFO')),
        format=opts.get('format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s'),
        handlers=handlers,
    )
def validate_file_path(file_path: str, allowed_extensions: Optional[List[str]] = None) -> bool:
    """Validate that *file_path* points to an existing regular file.

    Args:
        file_path: Path to check.
        allowed_extensions: Optional whitelist of lower-case suffixes
            (e.g. ``['.pdf', '.docx']``); when given, the file's suffix
            must be in the list.

    Returns:
        True when the path exists, is a regular file and (if restricted)
        carries an allowed extension; False otherwise. Each rejection
        reason is logged.
    """
    if not os.path.exists(file_path):
        logger.error(f"文件不存在: {file_path}")
        return False
    if not os.path.isfile(file_path):
        logger.error(f"路径不是文件: {file_path}")
        return False
    if allowed_extensions:
        # Compare the lower-cased suffix against the whitelist.
        ext = Path(file_path).suffix.lower()
        if ext not in allowed_extensions:
            logger.error(f"不支持的文件格式: {ext}")
            return False
    return True
def ensure_directory_exists(directory: str) -> bool:
    """Ensure *directory* exists, creating missing parents as needed.

    Returns True when the directory already exists or was created; on any
    failure the error is logged and False is returned instead of raising.
    """
    target = Path(directory)
    try:
        target.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        logger.error(f"创建目录失败: {e}")
        return False
    return True
def get_env_or_config(env_var: str, config_dict: Dict[str, Any],
                      default: Any = None) -> Any:
    """Resolve a setting from the environment first, then the config dict.

    Args:
        env_var: Environment variable name (also used as the config key).
        config_dict: Fallback configuration mapping.
        default: Value returned when neither source yields a usable value.

    Returns:
        The environment value if set and non-empty; otherwise the config
        value unless it is falsy or an unexpanded ``${...}`` placeholder
        string; otherwise *default*.
    """
    # Environment takes precedence; an empty string counts as "unset".
    env_value = os.environ.get(env_var)
    if env_value:
        return env_value
    config_value = config_dict.get(env_var)
    if config_value:
        # Only strings can be ${VAR} placeholders; non-string values
        # (floats, ints, bools, nested dicts) are returned as-is.  The
        # previous version called .startswith unconditionally and raised
        # AttributeError for numeric config values such as temperature.
        if not isinstance(config_value, str) or not config_value.startswith('${'):
            return config_value
    return default

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, List
import yaml
from app.core.config import settings
from app.tools.base import ToolDefinition
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen.src.document_parser import create_parser
from app.tools.srs_reqs_qwen.src.json_generator import JSONGenerator
from app.tools.srs_reqs_qwen.src.llm_interface import QwenLLM
from app.tools.srs_reqs_qwen.src.requirement_extractor import RequirementExtractor
class SRSTool:
    """Parse an SRS document and emit structured requirement records.

    Wraps the srs_reqs_qwen pipeline (document parser -> requirement
    extractor -> JSON generator) behind a single :meth:`run` entry point
    and registers its contract in the ToolRegistry.
    """
    TOOL_NAME = "srs.requirement_extractor"
    # Static tool contract advertised to the registry (JSON-schema style).
    DEFINITION = ToolDefinition(
        name=TOOL_NAME,
        version="1.0.0",
        description="Extract structured requirements from SRS documents.",
        input_schema={
            "type": "object",
            "properties": {
                "file_path": {"type": "string"},
                "enable_llm": {"type": "boolean"},
            },
            "required": ["file_path"],
        },
        output_schema={
            "type": "object",
            "properties": {
                "document_name": {"type": "string"},
                "generated_at": {"type": "string"},
                "requirements": {"type": "array"},
                "statistics": {"type": "object"},
                "raw_output": {"type": "object"},
            },
        },
    )
    # Priority label per requirement type.
    # NOTE(review): every value is an empty string — these look like CJK
    # labels (presumably 高/中/低) lost in an encoding step; confirm against
    # the original source and restore the intended characters.
    PRIORITY_BY_TYPE = {
        "functional": "",
        "interface": "",
        "performance": "",
        "security": "",
        "reliability": "",
        "other": "",
    }
    def __init__(self) -> None:
        # Re-registering on every instantiation is harmless: the registry
        # stores definitions keyed by name, so the entry is overwritten.
        ToolRegistry.register(self.DEFINITION)
    def run(self, file_path: str, enable_llm: bool = True) -> Dict[str, Any]:
        """Run the full extraction pipeline on *file_path*.

        Args:
            file_path: Path to the SRS document.
            enable_llm: When False (or no API key is configured), the
                pipeline runs without the LLM enhancement.

        Returns:
            Dict with document metadata, normalized ``requirements``,
            extractor ``statistics`` and the generator's ``raw_output``.
        """
        config = self._load_config()
        llm = self._build_llm(config, enable_llm=enable_llm)
        parser = create_parser(file_path)
        if llm is not None:
            parser.set_llm(llm)
        sections = parser.parse()
        # Fall back to the file name when no title could be parsed.
        document_title = parser.get_document_title() or Path(file_path).name
        extractor = RequirementExtractor(config, llm=llm)
        extracted = extractor.extract_from_sections(sections)
        stats = extractor.get_statistics()
        generator = JSONGenerator(config)
        raw_output = generator.generate(sections, extracted, document_title)
        requirements = self._normalize_requirements(extracted)
        return {
            "document_name": Path(file_path).name,
            "document_title": document_title,
            # Timestamp is read from the raw output's metadata block
            # ("文档元数据" -> "生成时间"); None when absent.
            "generated_at": raw_output.get("文档元数据", {}).get("生成时间"),
            "requirements": requirements,
            "statistics": stats,
            "raw_output": raw_output,
        }
    def _normalize_requirements(self, extracted: List[Any]) -> List[Dict[str, Any]]:
        """Flatten extractor result objects into plain dict records."""
        normalized: List[Dict[str, Any]] = []
        for index, req in enumerate(extracted, start=1):
            description = (req.description or "").strip()
            # Title is the first 40 chars of the description, or a numbered
            # placeholder when the description is empty.
            title = description[:40] if description else f"需求项 {index}"
            source_field = f"{req.section_number} {req.section_title}".strip() or "文档解析"
            normalized.append(
                {
                    "id": req.id,
                    "title": title,
                    "description": description,
                    # NOTE(review): defaults to "" — see PRIORITY_BY_TYPE note.
                    "priority": self.PRIORITY_BY_TYPE.get(req.type, ""),
                    "acceptance_criteria": [description] if description else ["待补充验收标准"],
                    "source_field": source_field,
                    "section_number": req.section_number,
                    "section_title": req.section_title,
                    "requirement_type": req.type,
                    "sort_order": index,
                }
            )
        return normalized
    def _load_config(self) -> Dict[str, Any]:
        """Load default_config.yaml (if present) and overlay LLM settings.

        Model, API key and base URL come from app settings, with the
        DASHSCOPE_API_KEY environment variable as a key fallback; LLM use
        is enabled only when an API key is actually available.
        """
        config_path = Path(__file__).with_name("default_config.yaml")
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as handle:
                config = yaml.safe_load(handle) or {}
        else:
            config = {}
        config.setdefault("llm", {})
        config["llm"]["model"] = settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
        config["llm"]["api_key"] = settings.DASH_SCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
        config["llm"]["api_base"] = settings.DASH_SCOPE_API_BASE
        config["llm"]["enabled"] = bool(config["llm"].get("api_key"))
        return config
    def _build_llm(self, config: Dict[str, Any], enable_llm: bool) -> QwenLLM | None:
        """Create the Qwen LLM client, or None when disabled or keyless."""
        if not enable_llm:
            return None
        llm_cfg = config.get("llm", {})
        api_key = llm_cfg.get("api_key")
        if not api_key:
            return None
        return QwenLLM(
            api_key=api_key,
            model=llm_cfg.get("model", "qwen3-max"),
            api_endpoint=llm_cfg.get("api_base") or settings.DASH_SCOPE_API_BASE,
            temperature=llm_cfg.get("temperature", 0.3),
            max_tokens=llm_cfg.get("max_tokens", 1024),
        )
# Process-wide lazily-created SRSTool instance.
_SRS_TOOL_SINGLETON: SRSTool | None = None


def get_srs_tool() -> SRSTool:
    """Return the shared SRSTool instance, creating it on first use."""
    global _SRS_TOOL_SINGLETON
    tool = _SRS_TOOL_SINGLETON
    if tool is None:
        tool = SRSTool()
        _SRS_TOOL_SINGLETON = tool
    return tool