Files

252 lines
7.3 KiB
Python
Raw Permalink Normal View History

2026-02-03 22:48:22 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SRS 解析工具 - 主程序入口
LLM 增强版 - 默认阿里云千问大模型
"""
import argparse
import os
import sys
import logging
from pathlib import Path
# 添加当前目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.utils import load_config, setup_logging, validate_file_path, ensure_directory_exists, get_env_or_config
from src.document_parser import create_parser
from src.requirement_extractor import RequirementExtractor
from src.json_generator import JSONGenerator
logger = logging.getLogger(__name__)
def create_llm(config: dict):
"""
创建LLM实例
Args:
config: 配置字典
Returns:
LLM实例或None
"""
llm_config = config.get('llm', {})
# 检查是否启用LLM
if not llm_config.get('enabled', True):
logger.info("LLM已禁用使用纯规则提取模式")
return None
provider = llm_config.get('provider', 'qwen')
# 获取API密钥优先使用环境变量
api_key = get_env_or_config('DASHSCOPE_API_KEY', llm_config.get('api_key'))
if not api_key:
logger.warning("未配置API密钥请使用纯规则提取模式")
logger.warning("请设置环境变量 DASHSCOPE_API_KEY 或在 config.yaml 中配置 llm.api_key")
return None
try:
from src.llm_interface import QwenLLM
model = llm_config.get('model', 'qwen-plus')
temperature = llm_config.get('temperature', 0.3)
max_tokens = llm_config.get('max_tokens', 1024)
llm = QwenLLM(
api_key=api_key,
model=model,
temperature=temperature,
max_tokens=max_tokens
)
logger.info(f"成功创建LLM实例: {provider} ({model})")
return llm
except ImportError as e:
logger.warning(f"无法导入LLM模块: {e}")
logger.warning("请运行: pip install dashscope")
return None
except Exception as e:
logger.warning(f"创建LLM实例失败: {e}")
return None
def main():
"""主程序入口"""
# 解析命令行参数
parser = argparse.ArgumentParser(
description='SRS需求文档解析工具',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法
python main.py --input sample.pdf --output output.json
python main.py -i requirements.docx -o output.json --verbose
python main.py -i DC-SRS.pdf -o output.json --no-llm # 禁用LLM
"""
)
parser.add_argument(
'--input', '-i',
type=str,
required=True,
help='输入的SRS文档路径支持.docx和.pdf'
)
parser.add_argument(
'--output', '-o',
type=str,
default='output.json',
help='输出JSON文件路径默认output.json'
)
parser.add_argument(
'--config', '-c',
type=str,
default=None,
help='配置文件路径(默认:./config.yaml'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='输出详细日志'
)
parser.add_argument(
'--no-llm',
action='store_true',
help='禁用LLM使用纯规则提取'
)
# 解析命令行参数
args = parser.parse_args()
# 加载配置
config = load_config(args.config)
# 命令行参数覆盖配置
if args.no_llm:
config.setdefault('llm', {})['enabled'] = False
# 设置日志
if args.verbose:
config.setdefault('logging', {})['level'] = 'DEBUG'
setup_logging(config)
logger.info("=" * 60)
logger.info("SRS需求文档解析工具启动LLM增强版")
logger.info("=" * 60)
try:
# 验证输入文件
if not validate_file_path(args.input, ['.pdf', '.docx']):
logger.error(f"输入文件验证失败: {args.input}")
return False
logger.info(f"输入文件: {args.input}")
# 创建输出目录
output_dir = os.path.dirname(args.output) or '.'
if output_dir != '.' and not ensure_directory_exists(output_dir):
logger.error(f"无法创建输出目录: {output_dir}")
return False
logger.info(f"输出文件: {args.output}")
# 创建LLM实例
llm = create_llm(config)
if llm:
logger.info("LLM增强模式已启用")
else:
logger.info("使用纯规则提取模式")
# 步骤1解析文档
logger.info("\n" + "=" * 60)
logger.info("步骤1解析文档")
logger.info("=" * 60)
doc_parser = create_parser(args.input)
if llm:
doc_parser.set_llm(llm)
sections = doc_parser.parse()
document_title = doc_parser.get_document_title()
logger.info(f"成功解析文档,提取{len(sections)}个顶级章节")
# 打印章节结构
def print_sections(sections, indent=0):
for section in sections:
logger.info(" " * indent + f"- {section.number} {section.title}")
if section.children:
print_sections(section.children, indent + 1)
if args.verbose:
logger.info("章节结构:")
print_sections(sections)
# 步骤2提取需求
logger.info("\n" + "=" * 60)
if llm:
logger.info("步骤2提取需求LLM增强模式")
else:
logger.info("步骤2提取需求规则匹配模式")
logger.info("=" * 60)
extractor = RequirementExtractor(config, llm=llm)
requirements = extractor.extract_from_sections(sections)
# 统计需求信息
stats = extractor.get_statistics()
logger.info(f"\n需求统计:")
for req_type, count in stats['by_type'].items():
logger.info(f" {req_type}: {count}")
logger.info(f" 总计: {stats['total']}")
# 步骤3生成JSON
logger.info("\n" + "=" * 60)
logger.info("步骤3生成JSON")
logger.info("=" * 60)
generator = JSONGenerator(config)
json_output = generator.generate(
sections,
requirements,
document_title
)
logger.info(f"JSON结构生成完成")
# 步骤4保存文件
logger.info("\n" + "=" * 60)
logger.info("步骤4保存结果")
logger.info("=" * 60)
generator.save_to_file(json_output, args.output)
logger.info(f"成功保存JSON文件到: {args.output}")
# 打印输出文件大小
if os.path.exists(args.output):
file_size = os.path.getsize(args.output)
logger.info(f"文件大小: {file_size} 字节")
logger.info("\n" + "=" * 60)
logger.info("SRS需求文档解析完成")
logger.info("=" * 60)
return True
except Exception as e:
logger.error(f"处理过程中出现错误: {e}", exc_info=True)
return False
if __name__ == '__main__':
success = main()
sys.exit(0 if success else 1)