Files
Extract_reqs/main.py
2026-02-03 22:48:22 +08:00

252 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SRS parsing tool - main program entry point.

LLM-enhanced edition: by default it uses Alibaba Cloud's Qwen large language
model (via the DashScope SDK). When no LLM is available, it falls back to
purely rule-based requirement extraction.
"""
import argparse
import os
import sys
import logging
from pathlib import Path
# Put this script's own directory on sys.path so the local `src` package
# imported below resolves no matter which working directory the script is
# launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.utils import load_config, setup_logging, validate_file_path, ensure_directory_exists, get_env_or_config
from src.document_parser import create_parser
from src.requirement_extractor import RequirementExtractor
from src.json_generator import JSONGenerator
# Module-level logger; handlers/levels are configured later by setup_logging().
logger = logging.getLogger(__name__)
def create_llm(config: dict):
    """
    Create the LLM client used to enhance document parsing and extraction.

    Only the Qwen backend (``src.llm_interface.QwenLLM``) is implemented.
    ``llm.provider`` is used for logging. A warning is emitted when it names
    any other provider, and Qwen is used anyway.

    Args:
        config: Application configuration dict. The ``llm`` section is read
            for ``enabled`` (default True), ``provider`` (default 'qwen'),
            ``api_key``, ``model`` (default 'qwen-plus'), ``temperature``
            (default 0.3) and ``max_tokens`` (default 1024).

    Returns:
        A ``QwenLLM`` instance, or None in any of these cases:
        - the LLM is disabled;
        - no API key is configured;
        - the LLM module cannot be imported;
        - construction fails.
        Failures are logged as warnings instead of raised, so the caller can
        fall back to rule-based extraction.
    """
    # A bare ``llm:`` entry in config.yaml would load as None rather than a
    # dict; treat that the same as an empty section instead of crashing.
    llm_config = config.get('llm') or {}

    # The LLM can be switched off in config (or via --no-llm in main()).
    if not llm_config.get('enabled', True):
        logger.info("LLM已禁用使用纯规则提取模式")
        return None

    provider = llm_config.get('provider', 'qwen')
    if provider != 'qwen':
        # Only QwenLLM exists below; make the mismatch visible instead of
        # silently building a Qwen client under another provider's name.
        logger.warning(f"LLM提供商 '{provider}' 暂不支持, 当前仅实现 qwen, 将使用 QwenLLM")

    # Resolve the API key, preferring the DASHSCOPE_API_KEY environment
    # variable over llm.api_key from the config file.
    api_key = get_env_or_config('DASHSCOPE_API_KEY', llm_config.get('api_key'))
    if not api_key:
        logger.warning("未配置API密钥请使用纯规则提取模式")
        logger.warning("请设置环境变量 DASHSCOPE_API_KEY 或在 config.yaml 中配置 llm.api_key")
        return None

    try:
        # Imported lazily so the tool still runs (rule-based) when the
        # optional LLM dependencies are not installed.
        from src.llm_interface import QwenLLM

        model = llm_config.get('model', 'qwen-plus')
        llm = QwenLLM(
            api_key=api_key,
            model=model,
            temperature=llm_config.get('temperature', 0.3),
            max_tokens=llm_config.get('max_tokens', 1024),
        )
        logger.info(f"成功创建LLM实例: {provider} ({model})")
        return llm
    except ImportError as e:
        logger.warning(f"无法导入LLM模块: {e}")
        logger.warning("请运行: pip install dashscope")
        return None
    except Exception as e:
        # Deliberate best effort: any construction failure downgrades to
        # rule-based mode rather than aborting the whole run.
        logger.warning(f"创建LLM实例失败: {e}")
        return None
# Usage examples shown after the option list in --help output.
_EPILOG = """
示例用法:
python main.py --input sample.pdf --output output.json
python main.py -i requirements.docx -o output.json --verbose
python main.py -i DC-SRS.pdf -o output.json --no-llm # 禁用LLM
"""


def _build_arg_parser():
    """Build the command-line parser for the SRS parsing tool."""
    parser = argparse.ArgumentParser(
        description='SRS需求文档解析工具',
        # Keep the epilog's line breaks exactly as written.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_EPILOG,
    )
    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='输入的SRS文档路径支持.docx和.pdf',
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        default='output.json',
        help='输出JSON文件路径默认output.json',
    )
    parser.add_argument(
        '--config', '-c',
        type=str,
        default=None,
        help='配置文件路径(默认:./config.yaml',
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='输出详细日志',
    )
    parser.add_argument(
        '--no-llm',
        action='store_true',
        help='禁用LLM使用纯规则提取',
    )
    return parser


def _config_section(config, key):
    """Return ``config[key]`` as a dict, creating it (or replacing a None value) when needed."""
    section = config.get(key)
    if section is None:
        # Covers both a missing key and a YAML entry with no nested values.
        section = config[key] = {}
    return section


def _log_banner(title, leading_newline=True):
    """Log ``title`` framed by 60-character '=' separator lines."""
    separator = "=" * 60
    logger.info(("\n" + separator) if leading_newline else separator)
    logger.info(title)
    logger.info(separator)


def _log_section_tree(sections, indent=0):
    """Recursively log the section hierarchy, one space of indent per nesting level."""
    for section in sections:
        logger.info(" " * indent + f"- {section.number} {section.title}")
        if section.children:
            _log_section_tree(section.children, indent + 1)


def _run_pipeline(args, config):
    """Run parse -> extract -> generate -> save; return False on a handled validation failure."""
    # Validate the input document.
    if not validate_file_path(args.input, ['.pdf', '.docx']):
        logger.error(f"输入文件验证失败: {args.input}")
        return False
    logger.info(f"输入文件: {args.input}")

    # Create the output directory only when the output path names one.
    output_dir = os.path.dirname(args.output) or '.'
    if output_dir != '.' and not ensure_directory_exists(output_dir):
        logger.error(f"无法创建输出目录: {output_dir}")
        return False
    logger.info(f"输出文件: {args.output}")

    # create_llm returns None whenever the LLM is unavailable.
    llm = create_llm(config)
    use_llm = llm is not None
    if use_llm:
        logger.info("LLM增强模式已启用")
    else:
        logger.info("使用纯规则提取模式")

    # Step 1: parse the document into a section tree.
    _log_banner("步骤1解析文档")
    doc_parser = create_parser(args.input)
    if use_llm:
        doc_parser.set_llm(llm)
    sections = doc_parser.parse()
    document_title = doc_parser.get_document_title()
    logger.info(f"成功解析文档,提取{len(sections)}个顶级章节")
    if args.verbose:
        logger.info("章节结构:")
        _log_section_tree(sections)

    # Step 2: extract requirements from the sections.
    _log_banner("步骤2提取需求LLM增强模式" if use_llm else "步骤2提取需求规则匹配模式")
    extractor = RequirementExtractor(config, llm=llm)
    requirements = extractor.extract_from_sections(sections)
    stats = extractor.get_statistics()
    logger.info("\n需求统计:")
    for req_type, count in stats['by_type'].items():
        logger.info(f" {req_type}: {count}")
    logger.info(f" 总计: {stats['total']}")

    # Step 3: build the JSON structure.
    _log_banner("步骤3生成JSON")
    generator = JSONGenerator(config)
    json_output = generator.generate(sections, requirements, document_title)
    logger.info("JSON结构生成完成")

    # Step 4: write the result to disk.
    _log_banner("步骤4保存结果")
    generator.save_to_file(json_output, args.output)
    logger.info(f"成功保存JSON文件到: {args.output}")
    if os.path.exists(args.output):
        logger.info(f"文件大小: {os.path.getsize(args.output)} 字节")

    _log_banner("SRS需求文档解析完成")
    return True


def main():
    """
    Command-line entry point: parse arguments, load config, run the pipeline.

    Command-line flags override the config file:
    - ``--no-llm`` sets ``llm.enabled = False``;
    - ``--verbose`` sets ``logging.level = 'DEBUG'``.

    Returns:
        True on success, False when validation fails or any exception is
        raised during processing. Processing exceptions are logged with
        traceback, not propagated.
    """
    args = _build_arg_parser().parse_args()

    config = load_config(args.config)
    if args.no_llm:
        _config_section(config, 'llm')['enabled'] = False
    if args.verbose:
        _config_section(config, 'logging')['level'] = 'DEBUG'
    setup_logging(config)

    _log_banner("SRS需求文档解析工具启动LLM增强版", leading_newline=False)

    try:
        return _run_pipeline(args, config)
    except Exception as e:
        # Top-level boundary: report with traceback and signal failure.
        logger.error(f"处理过程中出现错误: {e}", exc_info=True)
        return False
if __name__ == '__main__':
    # main() reports success as a bool; translate it into the conventional
    # process status (0 = success, 1 = failure).
    exit_code = 0 if main() else 1
    sys.exit(exit_code)