Files
Extract_reqs/main.py

315 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SRS 解析工具 - 主程序入口
"""
import argparse
import os
import sys
import logging
from pathlib import Path
# 添加当前目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.utils import load_config, setup_logging, validate_file_path, ensure_directory_exists, get_env_or_config
from src.document_parser import create_parser
from src.document_parser import Section
from src.requirement_extractor import RequirementExtractor
from src.json_generator import JSONGenerator
# Module-level logger shared by every function in this file; main() configures
# handlers/levels by calling setup_logging(config) before running the pipeline.
logger: logging.Logger = logging.getLogger(__name__)
def create_llm(config: dict):
    """Create the LLM client used by the parser and the requirement extractor.

    Only the Qwen (DashScope) backend is implemented. The API key is looked up
    via ``get_env_or_config``, preferring the ``DASHSCOPE_API_KEY`` environment
    variable over ``llm.api_key`` from the configuration.

    Args:
        config: Parsed configuration dictionary; only its ``llm`` section is
            read. A missing or empty (``None``) ``llm`` section is treated as
            an empty dict, i.e. all defaults.

    Returns:
        A ``QwenLLM`` instance. This function never returns ``None``; every
        failure is reported by raising.

    Raises:
        ValueError: If ``llm.enabled`` is false, or no API key is configured.
        RuntimeError: If the LLM module cannot be imported or the client
            cannot be constructed.
    """
    # `or {}` also covers a YAML section written as `llm:` with no body, which
    # loads as None and would otherwise fail on `.get`.
    llm_config = config.get('llm') or {}

    # The current version only supports LLM mode.
    if not llm_config.get('enabled', True):
        raise ValueError("当前版本仅支持LLM模式请将配置 llm.enabled 设为 true")

    provider = llm_config.get('provider', 'qwen')
    if str(provider).lower() != 'qwen':
        # Only QwenLLM exists: make it visible that another provider was
        # requested instead of silently using Qwen for it.
        logger.warning(f"不支持的LLM提供方: {provider},当前仅支持 qwen,将使用 QwenLLM")

    # The environment variable takes precedence over the config file value.
    api_key = get_env_or_config('DASHSCOPE_API_KEY', llm_config.get('api_key'))
    if not api_key:
        raise ValueError("未配置API密钥请设置环境变量 DASHSCOPE_API_KEY 或在 config.yaml 中配置 llm.api_key")

    model = llm_config.get('model', 'qwen-plus')
    temperature = llm_config.get('temperature', 0.3)
    max_tokens = llm_config.get('max_tokens', 1024)

    try:
        # Imported lazily so that a missing SDK dependency surfaces as a
        # RuntimeError carrying an install hint (see the ImportError branch).
        from src.llm_interface import QwenLLM
        llm = QwenLLM(
            api_key=api_key,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )
    except ImportError as e:
        raise RuntimeError(f"无法导入LLM模块: {e}。请安装依赖pip install dashscope") from e
    except Exception as e:
        raise RuntimeError(f"创建LLM实例失败: {e}") from e

    logger.info(f"成功创建LLM实例: {provider} ({model})")
    return llm
def parse_chapter_selector(selector: str) -> list:
    """Parse the ``--chapters`` argument into a list of chapter numbers.

    The selector is a comma-separated list of dotted chapter numbers such as
    ``"3"`` or ``"3,4.1"``. Both the ASCII comma and the full-width comma
    (``，``) are accepted as separators. Whitespace around items is stripped
    and empty items are skipped.

    Args:
        selector: Raw selector string; ``None`` or ``""`` means "no filter".

    Returns:
        The chapter numbers as strings, in input order (e.g. ``["3", "4.1"]``),
        or an empty list when *selector* is empty.

    Raises:
        ValueError: If an item is not made of ASCII digit groups joined by
            single dots (e.g. ``"3."``, ``".3"``, ``"a"``, ``"3²"``).
    """
    if not selector:
        return []

    # Normalise the full-width comma commonly typed with Chinese input methods.
    items = selector.replace('，', ',').split(',')

    chapters = []
    for raw in items:
        chapter = raw.strip()
        if not chapter:
            continue
        # str.isdigit() alone also accepts non-ASCII digits such as '²' or '３';
        # those would silently match no ASCII-numbered section, so require
        # plain ASCII digits as the error message promises.
        if not all(part.isascii() and part.isdigit() for part in chapter.split('.')):
            raise ValueError(f"无效章节编号: {chapter},仅支持如 3 或 3.1 的格式")
        chapters.append(chapter)
    return chapters
def _clone_section_with_children(section: Section) -> Section:
    """Structurally copy the section tree rooted at *section*.

    The copy receives the same level/title/number/content/uid, new ``tables``
    and ``blocks`` lists holding the same elements, and a recursively cloned
    child for every original child, attached with ``add_child`` in the
    original order. *section* itself is only read, never modified.
    """
    ctor_kwargs = {
        attr: getattr(section, attr)
        for attr in ('level', 'title', 'number', 'content', 'uid')
    }
    replica = Section(**ctor_kwargs)
    replica.tables = list(section.tables)
    replica.blocks = list(section.blocks)
    # Each child subtree is fully cloned before it is attached to the copy.
    for sub_section in section.children:
        replica.add_child(_clone_section_with_children(sub_section))
    return replica
def filter_sections_by_chapters(sections: list, chapters: list) -> list:
    """Keep only the parts of a section tree selected by chapter prefix.

    A section is selected when its (stripped) number equals one of *chapters*
    or starts with ``"<chapter>."``. For example, ``"3"`` selects ``3``,
    ``3.1`` and ``3.1.2`` but not ``31``. A selected section is returned
    together with its whole subtree (cloned via
    ``_clone_section_with_children``). An unselected section is kept only when
    one of its descendants is selected. It is then rebuilt with its own
    content/tables/blocks, but only the selected branches become its children.

    Args:
        sections: Top-level sections as produced by the document parser.
        chapters: Chapter numbers such as ``["3", "4.1"]``; empty means no
            filtering.

    Returns:
        *sections* itself when *chapters* is empty; otherwise a new list of
        newly built top-level sections (possibly empty). The input tree is not
        modified.
    """
    if not chapters:
        return sections

    # Precompute the match keys once instead of rebuilding them per section.
    exact = set(chapters)
    prefixes = tuple(f"{chapter}." for chapter in chapters)

    def matched(number) -> bool:
        """Return True if a section number falls under one of the chapters."""
        number = (number or "").strip()
        return bool(number) and (number in exact or number.startswith(prefixes))

    def recurse(section) -> "Section | None":
        """Return the filtered copy of *section*, or None if nothing is selected."""
        if matched(section.number):
            return _clone_section_with_children(section)
        # NOTE(review): this unselected ancestor keeps its own content, tables
        # and blocks, so downstream consumers also see text outside the
        # selected chapters - confirm this is intended.
        copied = Section(
            level=section.level,
            title=section.title,
            number=section.number,
            content=section.content,
            uid=section.uid,
        )
        copied.tables = list(section.tables)
        copied.blocks = list(section.blocks)
        for child in section.children:
            filtered_child = recurse(child)
            # recurse() signals "nothing selected" with None; compare against
            # None rather than relying on a Section's truthiness.
            if filtered_child is not None:
                copied.add_child(filtered_child)
        return copied if copied.children else None

    return [kept for kept in map(recurse, sections) if kept is not None]
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SRS parsing tool."""
    parser = argparse.ArgumentParser(
        description='SRS需求文档解析工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
python main.py --input sample.pdf --output output.json
python main.py -i requirements.docx -o output.json --verbose
python main.py -i DC-SRS.pdf -o output.json
"""
    )
    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='输入的SRS文档路径支持.docx和.pdf'
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        default='output.json',
        help='输出JSON文件路径默认output.json'
    )
    parser.add_argument(
        '--config', '-c',
        type=str,
        default=None,
        help='配置文件路径(默认:./config.yaml'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='输出详细日志'
    )
    parser.add_argument(
        '--chapters',
        type=str,
        default=None,
        help='按章节提取(如: 3 或 3,4.1输入3表示提取第3章及其子章节'
    )
    return parser


def _log_banner(title: str, *, leading_newline: bool = True) -> None:
    """Log *title* framed by two 60-character '=' rules.

    With *leading_newline* the first rule is preceded by a newline, which
    separates consecutive pipeline steps in the log output.
    """
    rule = "=" * 60
    logger.info(("\n" + rule) if leading_newline else rule)
    logger.info(title)
    logger.info(rule)


def _log_section_tree(sections: list, indent: int = 0) -> None:
    """Log each section as "- <number> <title>", indenting children by depth."""
    for section in sections:
        logger.info(" " * indent + f"- {section.number} {section.title}")
        if section.children:
            _log_section_tree(section.children, indent + 1)


def main():
    """Run the SRS parsing pipeline from the command line.

    Steps:
        1. Parse the arguments, load the configuration and set up logging.
        2. Create the LLM client.
        3. Parse the input document, optionally filtered by ``--chapters``.
        4. Extract requirements.
        5. Build the JSON structure and write it to ``--output``.

    Returns:
        True on success, False if any pipeline step failed (the error is
        logged with its traceback). Errors raised while loading the
        configuration or setting up logging are not caught here.
    """
    args = _build_arg_parser().parse_args()

    config = load_config(args.config)
    # --verbose overrides whatever log level the configuration specifies.
    if args.verbose:
        config.setdefault('logging', {})['level'] = 'DEBUG'
    setup_logging(config)

    _log_banner("SRS需求文档解析工具启动LLM增强版", leading_newline=False)

    try:
        if not validate_file_path(args.input, ['.pdf', '.docx']):
            logger.error(f"输入文件验证失败: {args.input}")
            return False
        logger.info(f"输入文件: {args.input}")

        # A bare file name has no directory part; only create real directories.
        output_dir = os.path.dirname(args.output) or '.'
        if output_dir != '.' and not ensure_directory_exists(output_dir):
            logger.error(f"无法创建输出目录: {output_dir}")
            return False
        logger.info(f"输出文件: {args.output}")

        # The LLM is mandatory: create_llm raises when it is disabled or
        # misconfigured, and the except below turns that into a False return.
        llm = create_llm(config)
        logger.info("LLM增强模式已启用")

        # Step 1: parse the document.
        _log_banner("步骤1解析文档")
        doc_parser = create_parser(args.input)
        if llm is not None:
            doc_parser.set_llm(llm)
        sections = doc_parser.parse()
        document_title = doc_parser.get_document_title()

        # parse_chapter_selector returns [] for a missing/empty --chapters.
        selected_chapters = parse_chapter_selector(args.chapters)
        if selected_chapters:
            sections = filter_sections_by_chapters(sections, selected_chapters)
            if not sections:
                raise ValueError(f"未匹配到指定章节: {', '.join(selected_chapters)}")
            logger.info(f"章节筛选已启用: {', '.join(selected_chapters)}")

        logger.info(f"成功解析文档,提取{len(sections)}个顶级章节")
        if args.verbose:
            logger.info("章节结构:")
            _log_section_tree(sections)

        # Step 2: extract requirements.
        _log_banner("步骤2提取需求LLM增强模式")
        extractor = RequirementExtractor(config, llm=llm)
        requirements = extractor.extract_from_sections(sections)

        stats = extractor.get_statistics()
        logger.info("\n需求统计:")
        for req_type, count in stats['by_type'].items():
            logger.info(f" {req_type}: {count}")
        logger.info(f" 总计: {stats['total']}")

        # Step 3: build the JSON structure.
        _log_banner("步骤3生成JSON")
        generator = JSONGenerator(config)
        json_output = generator.generate(
            sections,
            requirements,
            document_title
        )
        logger.info("JSON结构生成完成")

        # Step 4: write the result.
        _log_banner("步骤4保存结果")
        generator.save_to_file(json_output, args.output)
        logger.info(f"成功保存JSON文件到: {args.output}")
        if os.path.exists(args.output):
            file_size = os.path.getsize(args.output)
            logger.info(f"文件大小: {file_size} 字节")

        _log_banner("SRS需求文档解析完成")
        return True
    except Exception as e:
        # Top-level boundary: keep the full traceback in the log and report
        # failure through the return value instead of crashing.
        logger.exception(f"处理过程中出现错误: {e}")
        return False
if __name__ == '__main__':
    # main() reports success as a bool; map it onto the process exit status
    # (0 = success, 1 = failure). Raising SystemExit is what sys.exit() does.
    raise SystemExit(0 if main() else 1)