只保留LLM提取模式,修改提取逻辑
This commit is contained in:
123
main.py
123
main.py
@@ -2,7 +2,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SRS 解析工具 - 主程序入口
|
||||
LLM 增强版 - 默认阿里云千问大模型
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -16,6 +15,7 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.utils import load_config, setup_logging, validate_file_path, ensure_directory_exists, get_env_or_config
|
||||
from src.document_parser import create_parser
|
||||
from src.document_parser import Section
|
||||
from src.requirement_extractor import RequirementExtractor
|
||||
from src.json_generator import JSONGenerator
|
||||
|
||||
@@ -34,10 +34,9 @@ def create_llm(config: dict):
|
||||
"""
|
||||
llm_config = config.get('llm', {})
|
||||
|
||||
# 检查是否启用LLM
|
||||
# 当前版本仅支持LLM模式
|
||||
if not llm_config.get('enabled', True):
|
||||
logger.info("LLM已禁用,使用纯规则提取模式")
|
||||
return None
|
||||
raise ValueError("当前版本仅支持LLM模式,请将配置 llm.enabled 设为 true")
|
||||
|
||||
provider = llm_config.get('provider', 'qwen')
|
||||
|
||||
@@ -45,9 +44,7 @@ def create_llm(config: dict):
|
||||
api_key = get_env_or_config('DASHSCOPE_API_KEY', llm_config.get('api_key'))
|
||||
|
||||
if not api_key:
|
||||
logger.warning("未配置API密钥,请使用纯规则提取模式")
|
||||
logger.warning("请设置环境变量 DASHSCOPE_API_KEY 或在 config.yaml 中配置 llm.api_key")
|
||||
return None
|
||||
raise ValueError("未配置API密钥:请设置环境变量 DASHSCOPE_API_KEY 或在 config.yaml 中配置 llm.api_key")
|
||||
|
||||
try:
|
||||
from src.llm_interface import QwenLLM
|
||||
@@ -67,12 +64,80 @@ def create_llm(config: dict):
|
||||
return llm
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"无法导入LLM模块: {e}")
|
||||
logger.warning("请运行: pip install dashscope")
|
||||
return None
|
||||
raise RuntimeError(f"无法导入LLM模块: {e}。请安装依赖:pip install dashscope") from e
|
||||
except Exception as e:
|
||||
logger.warning(f"创建LLM实例失败: {e}")
|
||||
return None
|
||||
raise RuntimeError(f"创建LLM实例失败: {e}") from e
|
||||
|
||||
|
||||
def parse_chapter_selector(selector: str) -> list:
    """Parse the chapter filter argument into a list of chapter numbers.

    Args:
        selector: Comma-separated chapter numbers such as ``"3"`` or
            ``"3,4.1"``. Whitespace around each item is ignored and empty
            items are dropped.

    Returns:
        The chapter numbers as stripped strings, in input order. An empty
        list is returned when ``selector`` is falsy.

    Raises:
        ValueError: If an item has a dot-separated part that is not made up
            solely of digits (for example ``"3.a"`` or ``"3."``).
    """
    if not selector:
        return []

    parsed = []
    for raw_item in selector.split(','):
        token = raw_item.strip()
        if not token:
            # Skip blanks produced by stray commas such as "3,,4".
            continue
        # Every dot-separated component must be numeric; an empty component
        # (from "3." or "..") fails isdigit() and is rejected as well.
        if any(not part.isdigit() for part in token.split('.')):
            raise ValueError(f"无效章节编号: {token},仅支持如 3 或 3.1 的格式")
        parsed.append(token)
    return parsed
|
||||
|
||||
|
||||
def _clone_section_with_children(section: Section) -> Section:
    """Return a copy of ``section`` together with its whole subtree.

    The clone receives the same ``level``, ``title``, ``number``, ``content``
    and ``uid``. ``tables`` and ``blocks`` are copied into new lists (the
    elements themselves are shared), and every child is cloned recursively
    and attached through ``add_child``.
    """
    clone = Section(
        level=section.level,
        title=section.title,
        number=section.number,
        content=section.content,
        uid=section.uid,
    )
    # New list objects so that appending to the clone's lists does not
    # affect the source section.
    clone.tables = list(section.tables)
    clone.blocks = list(section.blocks)

    for sub_section in section.children:
        sub_clone = _clone_section_with_children(sub_section)
        clone.add_child(sub_clone)
    return clone
|
||||
|
||||
|
||||
def filter_sections_by_chapters(sections: list, chapters: list) -> list:
    """Filter a section tree by chapter-number prefix.

    A selector such as ``"3"`` matches the section numbered ``3`` and every
    section whose number starts with ``"3."``.

    Args:
        sections: Top-level sections of the parsed document.
        chapters: Chapter numbers to keep. When empty, ``sections`` is
            returned unchanged (the very same list object).

    Returns:
        A new list of copied sections. A matched section is copied with its
        entire subtree. An unmatched section is kept only if at least one
        descendant matched, in which case it is copied with just the
        surviving children.
    """
    if not chapters:
        return sections

    def is_selected(number: str) -> bool:
        # A section without a number can never be selected directly.
        normalized = (number or "").strip()
        if not normalized:
            return False
        # The trailing dot keeps "3" from matching "31" or "30.1".
        return any(
            normalized == wanted or normalized.startswith(f"{wanted}.")
            for wanted in chapters
        )

    def prune(node: Section) -> Section:
        if is_selected(node.number):
            return _clone_section_with_children(node)

        # Unmatched node: build a shell copy and keep only the children
        # that contain a match somewhere below them.
        shell = Section(
            level=node.level,
            title=node.title,
            number=node.number,
            content=node.content,
            uid=node.uid,
        )
        shell.tables = list(node.tables)
        shell.blocks = list(node.blocks)

        for sub_node in node.children:
            kept = prune(sub_node)
            if kept:
                shell.add_child(kept)

        if shell.children:
            return shell
        return None

    return [kept_top for kept_top in map(prune, sections) if kept_top]
|
||||
|
||||
|
||||
def main():
|
||||
@@ -86,7 +151,7 @@ def main():
|
||||
示例用法:
|
||||
python main.py --input sample.pdf --output output.json
|
||||
python main.py -i requirements.docx -o output.json --verbose
|
||||
python main.py -i DC-SRS.pdf -o output.json --no-llm # 禁用LLM
|
||||
python main.py -i DC-SRS.pdf -o output.json
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -116,11 +181,12 @@ def main():
|
||||
action='store_true',
|
||||
help='输出详细日志'
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--no-llm',
|
||||
action='store_true',
|
||||
help='禁用LLM,使用纯规则提取'
|
||||
'--chapters',
|
||||
type=str,
|
||||
default=None,
|
||||
help='按章节提取(如: 3 或 3,4.1);输入3表示提取第3章及其子章节'
|
||||
)
|
||||
|
||||
# 解析命令行参数
|
||||
@@ -129,10 +195,6 @@ def main():
|
||||
# 加载配置
|
||||
config = load_config(args.config)
|
||||
|
||||
# 命令行参数覆盖配置
|
||||
if args.no_llm:
|
||||
config.setdefault('llm', {})['enabled'] = False
|
||||
|
||||
# 设置日志
|
||||
if args.verbose:
|
||||
config.setdefault('logging', {})['level'] = 'DEBUG'
|
||||
@@ -158,12 +220,9 @@ def main():
|
||||
|
||||
logger.info(f"输出文件: {args.output}")
|
||||
|
||||
# 创建LLM实例
|
||||
# 创建LLM实例(必需)
|
||||
llm = create_llm(config)
|
||||
if llm:
|
||||
logger.info("LLM增强模式已启用")
|
||||
else:
|
||||
logger.info("使用纯规则提取模式")
|
||||
logger.info("LLM增强模式已启用")
|
||||
|
||||
# 步骤1:解析文档
|
||||
logger.info("\n" + "=" * 60)
|
||||
@@ -176,6 +235,13 @@ def main():
|
||||
|
||||
sections = doc_parser.parse()
|
||||
document_title = doc_parser.get_document_title()
|
||||
|
||||
selected_chapters = parse_chapter_selector(args.chapters) if args.chapters else []
|
||||
if selected_chapters:
|
||||
sections = filter_sections_by_chapters(sections, selected_chapters)
|
||||
if not sections:
|
||||
raise ValueError(f"未匹配到指定章节: {', '.join(selected_chapters)}")
|
||||
logger.info(f"章节筛选已启用: {', '.join(selected_chapters)}")
|
||||
|
||||
logger.info(f"成功解析文档,提取{len(sections)}个顶级章节")
|
||||
|
||||
@@ -192,10 +258,7 @@ def main():
|
||||
|
||||
# 步骤2:提取需求
|
||||
logger.info("\n" + "=" * 60)
|
||||
if llm:
|
||||
logger.info("步骤2:提取需求(LLM增强模式)")
|
||||
else:
|
||||
logger.info("步骤2:提取需求(规则匹配模式)")
|
||||
logger.info("步骤2:提取需求(LLM增强模式)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
extractor = RequirementExtractor(config, llm=llm)
|
||||
|
||||
Reference in New Issue
Block a user