完善skills;测试用例生成页面功能初步实现

This commit is contained in:
2026-05-05 19:45:33 +08:00
parent 0c2ed67e2a
commit 69b49d28b2
35 changed files with 4396 additions and 658 deletions

View File

@@ -0,0 +1,34 @@
"""add interface fields to srs requirements
Revision ID: b7217f0c3d92
Revises: a4f9c89b7d11
Create Date: 2026-04-18 19:10:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "b7217f0c3d92"
down_revision: Union[str, None] = "a4f9c89b7d11"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("srs_requirements", sa.Column("section_uid", sa.String(length=64), nullable=True))
op.add_column("srs_requirements", sa.Column("interface_name", sa.String(length=255), nullable=True))
op.add_column("srs_requirements", sa.Column("interface_type", sa.String(length=128), nullable=True))
op.add_column("srs_requirements", sa.Column("data_source", sa.String(length=255), nullable=True))
op.add_column("srs_requirements", sa.Column("data_destination", sa.String(length=255), nullable=True))
def downgrade() -> None:
op.drop_column("srs_requirements", "data_destination")
op.drop_column("srs_requirements", "data_source")
op.drop_column("srs_requirements", "interface_type")
op.drop_column("srs_requirements", "interface_name")
op.drop_column("srs_requirements", "section_uid")

View File

@@ -0,0 +1,59 @@
"""add testing generation history table
Revision ID: c9f6e7a1bd34
Revises: b7217f0c3d92
Create Date: 2026-04-26 20:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c9f6e7a1bd34"
down_revision: Union[str, None] = "b7217f0c3d92"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"testing_generations",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("job_id", sa.Integer(), nullable=False),
sa.Column("source_job_id", sa.Integer(), nullable=True),
sa.Column("source_document_name", sa.String(length=255), nullable=False),
sa.Column("generated_at", sa.DateTime(), nullable=False),
sa.Column("total_requirements", sa.Integer(), nullable=False, server_default="0"),
sa.Column("knowledge_base_id", sa.Integer(), nullable=True),
sa.Column("generated_file", sa.JSON(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["tool_jobs.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["knowledge_base_id"], ["knowledge_bases.id"]),
sa.ForeignKeyConstraint(["source_job_id"], ["tool_jobs.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("job_id"),
)
op.create_index(op.f("ix_testing_generations_id"), "testing_generations", ["id"], unique=False)
op.create_index(
op.f("ix_testing_generations_source_job_id"),
"testing_generations",
["source_job_id"],
unique=False,
)
op.create_index(
op.f("ix_testing_generations_knowledge_base_id"),
"testing_generations",
["knowledge_base_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_testing_generations_knowledge_base_id"), table_name="testing_generations")
op.drop_index(op.f("ix_testing_generations_source_job_id"), table_name="testing_generations")
op.drop_index(op.f("ix_testing_generations_id"), table_name="testing_generations")
op.drop_table("testing_generations")

View File

@@ -1,6 +1,8 @@
import logging
import asyncio
from typing import Any, Dict, List
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.config import settings
@@ -15,6 +17,8 @@ from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
router = APIRouter()
logger = logging.getLogger(__name__)
MODEL_PIPELINE_TIMEOUT_SECONDS = 300
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
@@ -47,38 +51,75 @@ async def generate_testing_content(
knowledge_context = (payload.knowledge_context or "").strip()
if payload.knowledge_base_ids:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id.in_(payload.knowledge_base_ids),
KnowledgeBase.user_id == current_user.id,
try:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id.in_(payload.knowledge_base_ids),
KnowledgeBase.user_id == current_user.id,
)
.all()
)
.all()
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
if kb_vector_stores:
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
retrieval_rows = await retriever.retrieve(
query=payload.requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
top_k=payload.retrieval_top_k,
)
if retrieval_rows:
knowledge_context = format_retrieval_context(retrieval_rows)
except Exception as exc:
logger.exception(
"Testing generation retrieval fallback triggered for user=%s knowledge_base_ids=%s: %s",
current_user.id,
payload.knowledge_base_ids,
exc,
)
pipeline_kwargs = {
"user_requirement_text": payload.requirement_text,
"requirement_type_input": payload.requirement_type,
"debug": payload.debug,
"knowledge_context": knowledge_context,
"use_model_generation": payload.use_model_generation,
"max_items_per_group": payload.max_items_per_group,
"cases_per_item": payload.cases_per_item,
"max_focus_points": payload.max_focus_points,
"max_llm_calls": payload.max_llm_calls,
}
try:
result = await asyncio.wait_for(
asyncio.to_thread(run_testing_pipeline, **pipeline_kwargs),
timeout=MODEL_PIPELINE_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError as exc:
logger.exception(
"Testing pipeline timed out for user=%s use_model_generation=%s after %s seconds",
current_user.id,
payload.use_model_generation,
MODEL_PIPELINE_TIMEOUT_SECONDS,
)
raise HTTPException(
status_code=504,
detail=f"LLM generation timed out after {MODEL_PIPELINE_TIMEOUT_SECONDS} seconds",
) from exc
except Exception as exc:
logger.exception(
"Testing pipeline failed for user=%s use_model_generation=%s: %s",
current_user.id,
payload.use_model_generation,
exc,
)
raise HTTPException(
status_code=500,
detail=f"LLM generation failed: {exc}",
) from exc
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
if kb_vector_stores:
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
retrieval_rows = await retriever.retrieve(
query=payload.requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
top_k=payload.retrieval_top_k,
)
if retrieval_rows:
knowledge_context = format_retrieval_context(retrieval_rows)
result = run_testing_pipeline(
user_requirement_text=payload.requirement_text,
requirement_type_input=payload.requirement_type,
debug=payload.debug,
knowledge_context=knowledge_context,
use_model_generation=payload.use_model_generation,
max_items_per_group=payload.max_items_per_group,
cases_per_item=payload.cases_per_item,
max_focus_points=payload.max_focus_points,
max_llm_calls=payload.max_llm_calls,
)
return result

View File

@@ -1,26 +1,46 @@
from pathlib import Path
from typing import Any, List
import shutil
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.tooling import SRSExtraction, ToolJob
from app.models.knowledge import KnowledgeBase
from app.models.tooling import SRSExtraction, TestingGeneration, ToolJob
from app.models.user import User
from app.schemas.tooling import (
SRSToolCreateJobResponse,
SRSToolHistoryItem,
SRSToolJobStatusResponse,
SRSToolRequirementsSaveRequest,
SRSToolResultResponse,
TestingGenerationCreateRequest,
TestingGenerationCreateResponse,
TestingGenerationHistoryItem,
TestingGenerationJobStatusResponse,
TestingGenerationResultResponse,
TestingGenerationSaveRequest,
ToolDefinitionResponse,
)
from app.services.srs_job_service import (
build_srs_upload_path,
build_result_response,
delete_srs_job,
ensure_upload_path,
list_srs_history,
replace_requirements,
run_srs_job,
)
from app.services.testing_generation_service import (
build_testing_generation_response,
create_testing_generation,
delete_testing_generation,
list_testing_history,
)
from app.services.testing_generation_job_service import run_testing_generation_job
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool
@@ -173,3 +193,223 @@ async def save_srs_requirements(
db.refresh(extraction)
return build_result_response(job, extraction)
@router.get("/srs/history", response_model=List[SRSToolHistoryItem])
async def get_srs_history(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
return list_srs_history(db, current_user.id)
@router.delete("/srs/jobs/{job_id}")
async def delete_srs_job_api(
job_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
job = (
db.query(ToolJob)
.filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
.first()
)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
upload_path = build_srs_upload_path(job_id)
delete_srs_job(db, job)
if upload_path.exists():
shutil.rmtree(upload_path, ignore_errors=True)
return {"message": "删除成功"}
@router.post("/testing/generations", response_model=TestingGenerationResultResponse)
async def save_testing_generation(
payload: TestingGenerationSaveRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
if payload.source_job_id is not None:
source_job = (
db.query(ToolJob)
.filter(
ToolJob.id == payload.source_job_id,
ToolJob.user_id == current_user.id,
)
.first()
)
if not source_job:
raise HTTPException(status_code=404, detail="来源文件不存在")
if payload.knowledge_base_id is not None:
knowledge_base = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == payload.knowledge_base_id,
KnowledgeBase.user_id == current_user.id,
)
.first()
)
if not knowledge_base:
raise HTTPException(status_code=404, detail="知识库不存在")
return create_testing_generation(db, current_user.id, payload)
@router.post("/testing/jobs", response_model=TestingGenerationCreateResponse)
async def create_testing_generation_job(
background_tasks: BackgroundTasks,
payload: TestingGenerationCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
if payload.source_job_id is not None:
source_job = (
db.query(ToolJob)
.filter(
ToolJob.id == payload.source_job_id,
ToolJob.user_id == current_user.id,
)
.first()
)
if not source_job:
raise HTTPException(status_code=404, detail="来源文件不存在")
if payload.knowledge_base_id is not None:
knowledge_base = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == payload.knowledge_base_id,
KnowledgeBase.user_id == current_user.id,
)
.first()
)
if not knowledge_base:
raise HTTPException(status_code=404, detail="知识库不存在")
job = ToolJob(
user_id=current_user.id,
tool_name="testing.case_generator",
status="pending",
input_file_name=payload.source_document_name,
input_file_path="",
output_summary={
"source_document_name": payload.source_document_name,
"current_step": 0,
"total_steps": len(payload.requirements),
},
)
db.add(job)
db.commit()
db.refresh(job)
background_tasks.add_task(
run_testing_generation_job,
job.id,
payload.dict(),
)
return {
"job_id": job.id,
"status": job.status,
}
@router.get("/testing/jobs/{job_id}", response_model=TestingGenerationJobStatusResponse)
async def get_testing_generation_job_status(
job_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
job = (
db.query(ToolJob)
.filter(
ToolJob.id == job_id,
ToolJob.user_id == current_user.id,
)
.first()
)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
summary = job.output_summary or {}
return {
"job_id": job.id,
"tool_name": job.tool_name,
"status": job.status,
"error_message": job.error_message,
"started_at": job.started_at,
"completed_at": job.completed_at,
"source_document_name": summary.get("source_document_name"),
"current_step": summary.get("current_step"),
"total_steps": summary.get("total_steps"),
"current_requirement_id": summary.get("current_requirement_id"),
}
@router.get("/testing/history", response_model=List[TestingGenerationHistoryItem])
async def get_testing_history(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
return list_testing_history(db, current_user.id)
@router.get("/testing/jobs/{job_id}/result", response_model=TestingGenerationResultResponse)
async def get_testing_generation_result(
job_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
job = (
db.query(ToolJob)
.filter(
ToolJob.id == job_id,
ToolJob.user_id == current_user.id,
)
.first()
)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
generation = (
db.query(TestingGeneration)
.filter(TestingGeneration.job_id == job.id)
.first()
)
if not generation:
raise HTTPException(status_code=404, detail="任务结果不存在")
return build_testing_generation_response(job, generation)
@router.delete("/testing/jobs/{job_id}")
async def delete_testing_generation_api(
job_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Any:
job = (
db.query(ToolJob)
.filter(
ToolJob.id == job_id,
ToolJob.user_id == current_user.id,
)
.first()
)
if not job:
raise HTTPException(status_code=404, detail="任务不存在")
generation = (
db.query(TestingGeneration)
.filter(TestingGeneration.job_id == job.id)
.first()
)
if not generation:
raise HTTPException(status_code=404, detail="任务结果不存在")
delete_testing_generation(db, job)
return {"message": "删除成功"}

View File

@@ -2,7 +2,7 @@ from .user import User
from .knowledge import KnowledgeBase, Document, DocumentChunk
from .chat import Chat, Message
from .api_key import APIKey
from .tooling import ToolJob, SRSExtraction, SRSRequirement
from .tooling import ToolJob, SRSExtraction, SRSRequirement, TestingGeneration
__all__ = [
"User",
@@ -15,4 +15,5 @@ __all__ = [
"ToolJob",
"SRSExtraction",
"SRSRequirement",
"TestingGeneration",
]

View File

@@ -29,6 +29,13 @@ class ToolJob(Base, TimestampMixin):
uselist=False,
cascade="all, delete-orphan",
)
testing_generation = relationship(
"TestingGeneration",
back_populates="job",
uselist=False,
cascade="all, delete-orphan",
foreign_keys="TestingGeneration.job_id",
)
class SRSExtraction(Base, TimestampMixin):
@@ -63,9 +70,14 @@ class SRSRequirement(Base, TimestampMixin):
priority = Column(String(16), nullable=False, default="")
acceptance_criteria = Column(JSON, nullable=False)
source_field = Column(String(255), nullable=False)
section_uid = Column(String(64), nullable=True)
section_number = Column(String(64), nullable=True)
section_title = Column(String(255), nullable=True)
requirement_type = Column(String(64), nullable=True)
interface_name = Column(String(255), nullable=True)
interface_type = Column(String(128), nullable=True)
data_source = Column(String(255), nullable=True)
data_destination = Column(String(255), nullable=True)
sort_order = Column(Integer, nullable=False, default=0)
extraction = relationship("SRSExtraction", back_populates="requirements")
@@ -74,3 +86,19 @@ class SRSRequirement(Base, TimestampMixin):
sa.UniqueConstraint("extraction_id", "requirement_uid", name="uq_srs_extraction_requirement_uid"),
sa.Index("idx_srs_requirements_extraction_sort", "extraction_id", "sort_order"),
)
class TestingGeneration(Base, TimestampMixin):
__tablename__ = "testing_generations"
id = Column(Integer, primary_key=True, index=True)
job_id = Column(Integer, ForeignKey("tool_jobs.id", ondelete="CASCADE"), nullable=False, unique=True)
source_job_id = Column(Integer, ForeignKey("tool_jobs.id"), nullable=True, index=True)
source_document_name = Column(String(255), nullable=False)
generated_at = Column(DateTime, default=datetime.utcnow, nullable=False)
total_requirements = Column(Integer, nullable=False, default=0)
knowledge_base_id = Column(Integer, ForeignKey("knowledge_bases.id"), nullable=True, index=True)
generated_file = Column(JSON, nullable=False)
job = relationship("ToolJob", back_populates="testing_generation", foreign_keys=[job_id])

View File

@@ -29,14 +29,18 @@ class SRSToolJobStatusResponse(BaseModel):
class SRSToolRequirementItem(BaseModel):
id: str
title: str
description: str
priority: str
acceptanceCriteria: List[str]
sourceField: str
sectionUid: Optional[str] = None
sectionNumber: Optional[str] = None
sectionTitle: Optional[str] = None
requirementType: Optional[str] = None
interfaceName: Optional[str] = None
interfaceType: Optional[str] = None
dataSource: Optional[str] = None
dataDestination: Optional[str] = None
sortOrder: int
@@ -46,7 +50,72 @@ class SRSToolResultResponse(BaseModel):
generatedAt: str
statistics: Dict[str, Any]
requirements: List[SRSToolRequirementItem]
rawOutput: Dict[str, Any]
class SRSToolHistoryItem(BaseModel):
jobId: int
documentName: str
generatedAt: str
totalRequirements: int
status: str
createdAt: str
updatedAt: str
class SRSToolRequirementsSaveRequest(BaseModel):
requirements: List[SRSToolRequirementItem]
class TestingGenerationSaveRequest(BaseModel):
source_job_id: Optional[int] = None
source_document_name: str
knowledge_base_id: Optional[int] = None
generated_file: Dict[str, Any]
class TestingGenerationCreateRequest(BaseModel):
source_job_id: Optional[int] = None
source_document_name: str
knowledge_base_id: Optional[int] = None
requirements: List[SRSToolRequirementItem]
class TestingGenerationCreateResponse(BaseModel):
job_id: int
status: str
class TestingGenerationJobStatusResponse(BaseModel):
job_id: int
tool_name: str
status: str
error_message: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
source_document_name: Optional[str] = None
current_step: Optional[int] = None
total_steps: Optional[int] = None
current_requirement_id: Optional[str] = None
class TestingGenerationResultResponse(BaseModel):
jobId: int
sourceJobId: Optional[int] = None
sourceDocumentName: str
generatedAt: str
totalRequirements: int
knowledgeBaseId: Optional[int] = None
generatedFile: Dict[str, Any]
class TestingGenerationHistoryItem(BaseModel):
jobId: int
sourceJobId: Optional[int] = None
sourceDocumentName: str
generatedAt: str
totalRequirements: int
knowledgeBaseId: Optional[int] = None
status: str
createdAt: str
updatedAt: str

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
@@ -10,6 +11,45 @@ from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool
TYPE_TO_CHINESE = {
"functional": "功能需求",
"interface": "接口需求",
"performance": "性能需求",
"security": "安全需求",
"reliability": "可靠性需求",
"other": "其他需求",
}
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
text = str(description or "").strip()
if not text:
return fallback or f"需求项 {index + 1}"
for separator in ("", "", "\n", ";", "."):
if separator in text:
text = text.split(separator, 1)[0].strip()
break
text = text[:20].strip()
return text or fallback or f"需求项 {index + 1}"
def _normalize_requirement_type(value: Any) -> str:
text = str(value or "").strip()
if text in {"functional", "interface", "performance", "security", "reliability", "other"}:
return text
chinese_map = {
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
"功能需求": "functional",
}
return chinese_map.get(text, "functional")
def run_srs_job(job_id: int) -> None:
db = SessionLocal()
@@ -41,14 +81,19 @@ def run_srs_job(job_id: int) -> None:
requirement = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=item["id"],
title=item.get("title") or item["id"],
title=_build_internal_title(item.get("description"), item["id"]),
description=item.get("description") or "",
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
source_field=item.get("source_field") or "文档解析",
section_uid=item.get("section_uid"),
section_number=item.get("section_number"),
section_title=item.get("section_title"),
requirement_type=item.get("requirement_type"),
interface_name=item.get("interface_name"),
interface_type=item.get("interface_type"),
data_source=item.get("data_source"),
data_destination=item.get("data_destination"),
sort_order=int(item.get("sort_order") or 0),
)
db.add(requirement)
@@ -97,22 +142,8 @@ def ensure_upload_path(job_id: int, file_name: str) -> Path:
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
requirements: List[Dict[str, Any]] = []
for item in extraction.requirements:
requirements.append(
{
"id": item.requirement_uid,
"title": item.title,
"description": item.description,
"priority": item.priority,
"acceptanceCriteria": item.acceptance_criteria or [],
"sourceField": item.source_field,
"sectionNumber": item.section_number,
"sectionTitle": item.section_title,
"requirementType": item.requirement_type,
"sortOrder": item.sort_order,
}
)
requirements = [_requirement_model_to_payload(item) for item in extraction.requirements]
raw_output = _merge_updates_into_raw_output(extraction.raw_output, requirements, extraction.document_name)
return {
"jobId": job.id,
@@ -120,10 +151,12 @@ def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str,
"generatedAt": extraction.generated_at.isoformat(),
"statistics": extraction.statistics or {},
"requirements": requirements,
"rawOutput": raw_output,
}
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
normalized_updates = [_normalize_update_payload(item, index) for index, item in enumerate(updates)]
existing = {
req.requirement_uid: req
for req in db.query(SRSRequirement)
@@ -132,51 +165,95 @@ def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[D
}
seen_ids = set()
for index, item in enumerate(updates):
for index, item in enumerate(normalized_updates):
uid = item["id"]
seen_ids.add(uid)
req = existing.get(uid)
if req is None:
description = item.get("description") if item.get("description") is not None else ""
req = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=uid,
title=item.get("title") or uid,
description=item.get("description") if item.get("description") is not None else "",
title=_build_internal_title(description, uid, int(item.get("sortOrder") or index)),
description=description,
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
source_field=item.get("sourceField") or "文档解析",
section_uid=item.get("sectionUid"),
section_number=item.get("sectionNumber"),
section_title=item.get("sectionTitle"),
requirement_type=item.get("requirementType"),
sort_order=int(item.get("sortOrder") or index),
interface_name=item.get("interfaceName"),
interface_type=item.get("interfaceType"),
data_source=item.get("dataSource"),
data_destination=item.get("dataDestination"),
sort_order=int(item.get("sortOrder") or 0),
)
db.add(req)
continue
req.title = item.get("title", req.title)
req.description = item.get("description", req.description)
req.title = _build_internal_title(req.description, req.requirement_uid, int(item.get("sortOrder") or index))
req.priority = item.get("priority", req.priority)
req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
req.source_field = item.get("sourceField", req.source_field)
req.section_uid = item.get("sectionUid", req.section_uid)
req.section_number = item.get("sectionNumber", req.section_number)
req.section_title = item.get("sectionTitle", req.section_title)
req.requirement_type = item.get("requirementType", req.requirement_type)
req.sort_order = int(item.get("sortOrder", index))
req.interface_name = item.get("interfaceName", req.interface_name)
req.interface_type = item.get("interfaceType", req.interface_type)
req.data_source = item.get("dataSource", req.data_source)
req.data_destination = item.get("dataDestination", req.data_destination)
req.sort_order = int(item.get("sortOrder", req.sort_order))
for uid, req in existing.items():
if uid not in seen_ids:
db.delete(req)
extraction.total_requirements = len(updates)
extraction.total_requirements = len(normalized_updates)
extraction.statistics = {
"total": len(updates),
"by_type": _count_requirement_types(updates),
}
extraction.raw_output = {
"document_name": extraction.document_name,
"generated_at": extraction.generated_at.isoformat(),
"requirements": updates,
"total": len(normalized_updates),
"by_type": _count_requirement_types(normalized_updates),
}
extraction.raw_output = _merge_updates_into_raw_output(
extraction.raw_output,
normalized_updates,
extraction.document_name,
)
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
records: List[Tuple[ToolJob, SRSExtraction]] = (
db.query(ToolJob, SRSExtraction)
.join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
.filter(ToolJob.user_id == user_id)
.order_by(ToolJob.created_at.desc())
.all()
)
items: List[Dict[str, Any]] = []
for job, extraction in records:
items.append(
{
"jobId": job.id,
"documentName": extraction.document_name,
"generatedAt": extraction.generated_at.isoformat(),
"totalRequirements": extraction.total_requirements,
"status": job.status,
"createdAt": job.created_at.isoformat(),
"updatedAt": job.updated_at.isoformat(),
}
)
return items
def delete_srs_job(db: Session, job: ToolJob) -> None:
db.delete(job)
db.commit()
def build_srs_upload_path(job_id: int) -> Path:
return Path("uploads") / "srs_jobs" / str(job_id)
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
@@ -185,3 +262,210 @@ def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
req_type = item.get("requirementType") or "functional"
stats[req_type] = stats.get(req_type, 0) + 1
return stats
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
return {
"id": item.requirement_uid,
"description": item.description,
"priority": item.priority,
"acceptanceCriteria": item.acceptance_criteria or [],
"sourceField": item.source_field,
"sectionUid": item.section_uid,
"sectionNumber": item.section_number,
"sectionTitle": item.section_title,
"requirementType": item.requirement_type,
"interfaceName": item.interface_name,
"interfaceType": item.interface_type,
"dataSource": item.data_source,
"dataDestination": item.data_destination,
"sortOrder": item.sort_order,
}
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
requirement_type = _normalize_requirement_type(item.get("requirementType"))
normalized = {
"id": str(item.get("id") or f"REQ-{index + 1:03d}"),
"description": str(item.get("description") or "").strip(),
"priority": item.get("priority") or "",
"acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
"sourceField": item.get("sourceField") or "文档解析",
"sectionUid": item.get("sectionUid"),
"sectionNumber": item.get("sectionNumber"),
"sectionTitle": item.get("sectionTitle"),
"requirementType": requirement_type,
"interfaceName": item.get("interfaceName") or "",
"interfaceType": item.get("interfaceType") or "",
"dataSource": item.get("dataSource") or "",
"dataDestination": item.get("dataDestination") or "",
"sortOrder": int(item.get("sortOrder") or index),
}
if requirement_type != "interface":
normalized["interfaceName"] = ""
normalized["interfaceType"] = ""
normalized["dataSource"] = ""
normalized["dataDestination"] = ""
return normalized
def _merge_updates_into_raw_output(
raw_output: Dict[str, Any] | None,
updates: List[Dict[str, Any]],
document_name: str,
) -> Dict[str, Any]:
if not isinstance(raw_output, dict) or "需求内容" not in raw_output:
return _build_raw_output_from_flat(updates, document_name)
result = deepcopy(raw_output)
content = result.get("需求内容")
if not isinstance(content, dict):
return _build_raw_output_from_flat(updates, document_name)
updates_by_section = _group_updates_by_section(updates)
_rewrite_content_requirements(content, updates_by_section)
_append_unmatched_sections(content, updates_by_section)
_refresh_metadata(result, updates)
return result
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
grouped: Dict[str, List[Dict[str, Any]]] = {}
for item in updates:
key = _section_key(item.get("sectionUid"), item.get("sectionNumber"), item.get("sectionTitle"))
grouped.setdefault(key, []).append(item)
for values in grouped.values():
values.sort(key=lambda value: int(value.get("sortOrder") or 0))
return grouped
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
for section in content.values():
if not isinstance(section, dict):
continue
section_info = section.get("章节信息") or {}
section_key = _section_key(
section_info.get("章节UID"),
section_info.get("章节编号"),
section_info.get("章节标题"),
)
if section_key in updates_by_section:
section["需求列表"] = [_to_raw_requirement_item(item) for item in updates_by_section.pop(section_key)]
elif "需求列表" in section:
section["需求列表"] = []
children = section.get("子章节")
if isinstance(children, dict):
_rewrite_content_requirements(children, updates_by_section)
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
if not updates_by_section:
return
orphan_key = "未归类章节"
orphan_section = content.get(orphan_key)
if not isinstance(orphan_section, dict):
orphan_section = {
"章节信息": {
"章节编号": "",
"章节标题": orphan_key,
"章节级别": 1,
},
"需求列表": [],
}
content[orphan_key] = orphan_section
all_reqs: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
for values in updates_by_section.values():
for item in values:
all_reqs.append(_to_raw_requirement_item(item))
orphan_section["需求列表"] = all_reqs
updates_by_section.clear()
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
metadata = raw_output.get("文档元数据")
if not isinstance(metadata, dict):
metadata = {}
raw_output["文档元数据"] = metadata
metadata["总需求数"] = len(updates)
metadata["生成时间"] = datetime.now().isoformat()
type_stats: Dict[str, int] = {}
for item in updates:
req_type = item.get("requirementType") or "functional"
cn_type = TYPE_TO_CHINESE.get(req_type, "其他需求")
type_stats[cn_type] = type_stats.get(cn_type, 0) + 1
metadata["需求类型统计"] = type_stats
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
req_type = item.get("requirementType") or "functional"
raw_item = {
"需求类型": TYPE_TO_CHINESE.get(req_type, "其他需求"),
"需求编号": item.get("id") or "",
"需求描述": item.get("description") or "",
"优先级": item.get("priority") or "",
}
if req_type == "interface":
raw_item["接口名称"] = item.get("interfaceName") or ""
raw_item["接口类型"] = item.get("interfaceType") or ""
raw_item["数据来源"] = item.get("dataSource") or ""
raw_item["数据目的地"] = item.get("dataDestination") or ""
return raw_item
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
grouped = _group_updates_by_section(updates)
content: Dict[str, Any] = {}
for key, values in grouped.items():
number, title = _parse_section_key(key)
section_title = title or "未归类章节"
display_key = f"{number} {section_title}".strip()
content[display_key] = {
"章节信息": {
"章节编号": number,
"章节标题": section_title,
"章节级别": max(len(number.split(".")), 1) if number else 1,
},
"需求列表": [_to_raw_requirement_item(item) for item in values],
}
raw_output = {
"文档元数据": {
"标题": document_name,
"生成时间": datetime.now().isoformat(),
"总需求数": len(updates),
"需求类型统计": {},
},
"需求内容": content,
}
_refresh_metadata(raw_output, updates)
return raw_output
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
uid = str(section_uid or "").strip()
if uid:
return f"uid::{uid}"
number = str(section_number or "").strip()
title = str(section_title or "").strip()
return f"number::{number}::title::{title}"
def _parse_section_key(value: str) -> Tuple[str, str]:
if value.startswith("uid::"):
return "", "未归类章节"
number = ""
title = ""
parts = value.split("::")
if len(parts) >= 4:
number = parts[1]
title = parts[3]
return number, title

View File

@@ -0,0 +1,236 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
items.extend(current)
return items
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
if not documents:
continue
store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb.id}",
embedding_function=embeddings,
)
kb_vector_stores.append({"kb_id": kb.id, "store": store})
return kb_vector_stores
def _resolve_knowledge_context(
db: Session,
*,
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
) -> str:
if knowledge_base_id is None:
return ""
try:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == knowledge_base_id,
KnowledgeBase.user_id == user_id,
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
if not kb_vector_stores:
return ""
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
rows = asyncio.run(
retriever.retrieve(
query=requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=8,
)
)
if rows:
return format_retrieval_context(rows)
except Exception:
return ""
return ""
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
test_items = [
{
"id": item.get("id"),
"content": item.get("content"),
}
for item in _flatten_record(pipeline_result.get("test_items", {}))
]
test_cases = [
{
"id": item.get("id"),
"itemId": item.get("item_id"),
"testContent": item.get("test_content"),
"operationSteps": item.get("operation_steps", []),
"expectedResultPlaceholder": item.get("expected_result_placeholder"),
}
for item in _flatten_record(pipeline_result.get("test_cases", {}))
]
expected_results = [
{
"id": item.get("id"),
"caseId": item.get("case_id"),
"result": item.get("result"),
}
for item in _flatten_record(pipeline_result.get("expected_results", {}))
]
return {
**req,
"测试项": test_items,
"测试用例": test_cases,
"预期结果": expected_results,
}
def _mark_job_failed(job_id: int, error_message: str) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
job.status = "failed"
job.completed_at = datetime.utcnow()
job.error_message = error_message[:2000]
db.commit()
finally:
db.close()
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
requirements = payload.get("requirements") or []
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
source_job_id = payload.get("source_job_id")
knowledge_base_id = payload.get("knowledge_base_id")
job.status = "processing"
job.started_at = datetime.utcnow()
job.error_message = None
job.output_summary = {
"source_document_name": source_document_name,
"current_step": 0,
"total_steps": len(requirements),
}
db.commit()
generated_requirements: List[Dict[str, Any]] = []
for index, req in enumerate(requirements):
req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
job.output_summary = {
"source_document_name": source_document_name,
"current_step": index + 1,
"total_steps": len(requirements),
"current_requirement_id": req_id,
}
db.commit()
description = str(req.get("description") or "").strip()
if not description:
generated_requirements.append(
{
**req,
"测试项": [],
"测试用例": [],
"预期结果": [],
}
)
continue
knowledge_context = _resolve_knowledge_context(
db,
user_id=job.user_id,
requirement_text=description,
knowledge_base_id=knowledge_base_id,
)
pipeline_result = run_testing_pipeline(
user_requirement_text=description,
requirement_type_input=req.get("requirementType"),
debug=False,
knowledge_context=knowledge_context,
use_model_generation=True,
max_items_per_group=12,
cases_per_item=2,
max_focus_points=6,
max_llm_calls=10,
)
generated_requirements.append(_build_generated_requirement(req, pipeline_result))
generated_at = datetime.utcnow()
generated_file = {
"sourceDocument": source_document_name,
"sourceJobId": source_job_id,
"generatedAt": generated_at.isoformat(),
"totalRequirements": len(generated_requirements),
"knowledgeBaseId": knowledge_base_id,
"requirements": generated_requirements,
}
generation = TestingGeneration(
job_id=job.id,
source_job_id=source_job_id,
source_document_name=source_document_name,
generated_at=generated_at,
total_requirements=len(generated_requirements),
knowledge_base_id=knowledge_base_id,
generated_file=generated_file,
)
db.add(generation)
job.status = "completed"
job.completed_at = datetime.utcnow()
job.output_summary = {
"source_document_name": source_document_name,
"current_step": len(generated_requirements),
"total_steps": len(generated_requirements),
"total_requirements": len(generated_requirements),
"knowledge_base_id": knowledge_base_id,
}
db.commit()
except Exception as exc:
db.rollback()
_mark_job_failed(job_id, str(exc))
finally:
db.close()

View File

@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
from app.models.tooling import TestingGeneration, ToolJob
from app.schemas.tooling import TestingGenerationSaveRequest
TESTING_TOOL_NAME = "testing.case_generator"
def _resolve_total_requirements(generated_file: Dict[str, Any]) -> int:
requirements = generated_file.get("requirements")
if isinstance(requirements, list):
return len(requirements)
total = generated_file.get("totalRequirements")
if isinstance(total, int) and total >= 0:
return total
return 0
def build_testing_generation_response(job: ToolJob, generation: TestingGeneration) -> Dict[str, Any]:
return {
"jobId": job.id,
"sourceJobId": generation.source_job_id,
"sourceDocumentName": generation.source_document_name,
"generatedAt": generation.generated_at.isoformat(),
"totalRequirements": generation.total_requirements,
"knowledgeBaseId": generation.knowledge_base_id,
"generatedFile": generation.generated_file or {},
}
def create_testing_generation(
db: Session,
user_id: int,
payload: TestingGenerationSaveRequest,
) -> Dict[str, Any]:
now = datetime.utcnow()
total_requirements = _resolve_total_requirements(payload.generated_file)
job = ToolJob(
user_id=user_id,
tool_name=TESTING_TOOL_NAME,
status="completed",
input_file_name=payload.source_document_name,
input_file_path="",
started_at=now,
completed_at=now,
output_summary={
"source_document_name": payload.source_document_name,
"total_requirements": total_requirements,
"knowledge_base_id": payload.knowledge_base_id,
},
)
db.add(job)
db.flush()
generation = TestingGeneration(
job_id=job.id,
source_job_id=payload.source_job_id,
source_document_name=payload.source_document_name,
generated_at=now,
total_requirements=total_requirements,
knowledge_base_id=payload.knowledge_base_id,
generated_file=payload.generated_file,
)
db.add(generation)
db.commit()
db.refresh(job)
db.refresh(generation)
return build_testing_generation_response(job, generation)
def list_testing_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
rows: List[Tuple[ToolJob, TestingGeneration]] = (
db.query(ToolJob, TestingGeneration)
.join(TestingGeneration, TestingGeneration.job_id == ToolJob.id)
.filter(
ToolJob.user_id == user_id,
ToolJob.tool_name == TESTING_TOOL_NAME,
)
.order_by(ToolJob.created_at.desc())
.all()
)
items: List[Dict[str, Any]] = []
for job, generation in rows:
items.append(
{
"jobId": job.id,
"sourceJobId": generation.source_job_id,
"sourceDocumentName": generation.source_document_name,
"generatedAt": generation.generated_at.isoformat(),
"totalRequirements": generation.total_requirements,
"knowledgeBaseId": generation.knowledge_base_id,
"status": job.status,
"createdAt": job.created_at.isoformat(),
"updatedAt": job.updated_at.isoformat(),
}
)
return items
def delete_testing_generation(db: Session, job: ToolJob) -> None:
db.delete(job)
db.commit()

View File

@@ -4,7 +4,6 @@ from time import perf_counter
from typing import Any, Dict, List, Optional
from uuid import uuid4
from app.services.llm.llm_factory import LLMFactory
from app.services.testing_pipeline.tools import build_default_tool_chain
@@ -42,6 +41,8 @@ def run_testing_pipeline(
llm_model = None
if use_model_generation:
try:
from app.services.llm.llm_factory import LLMFactory
llm_model = LLMFactory.create(streaming=False)
except Exception:
llm_model = None

View File

@@ -137,6 +137,19 @@ class DocumentParser(ABC):
chinese_numbers = '一二三四五六七八九十百千万'
return text and all(c in chinese_numbers for c in text)
def _section_sort_key(self, section: 'Section') -> Tuple[int, List[int], str]:
number = (section.number or "").strip()
if number and re.match(r'^\d+(?:\.\d+)*$', number):
return (0, [int(part) for part in number.split('.')], section.title or "")
return (1, [section.level], section.title or "")
def _sort_sections_by_number(self, sections: List['Section']) -> List['Section']:
ordered = sorted(sections, key=self._section_sort_key)
for section in ordered:
if section.children:
section.children = self._sort_sections_by_number(section.children)
return ordered
class DocxParser(DocumentParser):
"""DOCX格式文档解析器"""
@@ -210,6 +223,7 @@ class DocxParser(DocumentParser):
# 为没有编号的章节自动生成编号
self._auto_number_sections(self.sections)
self.sections = self._sort_sections_by_number(self.sections)
logger.info(f"完成Docx解析提取{len(self.sections)}个顶级章节")
return self.sections
@@ -236,12 +250,17 @@ class DocxParser(DocumentParser):
"""解析标题,返回(编号, 标题, 级别)"""
style_name = paragraph.style.name if paragraph.style else ""
is_heading_style = style_name.lower().startswith('heading') if style_name else False
if self._is_calendar_line(text):
return None
# 数字编号标题
match = re.match(r'^(\d+(?:\.\d+)*)\s*[\.、]?\s*(.+)$', text)
match = re.match(r'^(\d+(?:\.\d+)*)\s*[\.、):\-_/]?\s*(.+)$', text)
if match and self._is_valid_heading(match.group(2)):
number = match.group(1)
title = match.group(2).strip()
if not self._is_valid_numbered_heading(number, title):
return None
level = len(number.split('.'))
return number, title, level
@@ -263,6 +282,31 @@ class DocxParser(DocumentParser):
return None
def _is_calendar_line(self, text: str) -> bool:
value = (text or "").strip().replace(" ", "")
return bool(re.match(r'^\d{4}\d{1,2}月(?:\d{1,2}日)?$', value))
def _is_valid_numbered_heading(self, number: str, title: str) -> bool:
parts = number.split('.')
if len(parts) > 6:
return False
first = int(parts[0])
if first < 1 or first > 30:
return False
for part in parts[1:]:
if int(part) > 30:
return False
if len(parts) == 1 and re.match(r'^年\d{1,2}月', title):
return False
if title and title[0].isdigit():
return False
return True
def _iter_block_items(self, parent):
"""按文档顺序迭代段落和表格"""
from docx.text.paragraph import Paragraph
@@ -356,6 +400,7 @@ class PDFParser(DocumentParser):
# 6. 为没有编号的章节自动生成编号
self._auto_number_sections(self.sections)
self.sections = self._sort_sections_by_number(self.sections)
logger.info(f"完成PDF解析提取{len(self.sections)}个顶级章节")
return self.sections
@@ -599,18 +644,6 @@ class PDFParser(DocumentParser):
level = len(number.split('.'))
top_level_number = int(number.split('.')[0])
# 顶级章节序号大幅跳跃通常是误识别如正文中的“8 表...”)。
if level == 1 and last_top_level_number and top_level_number > last_top_level_number + 1:
if line and not self._is_noise(line):
content_buffer.append(line)
continue
# 顶级章节编号倒退通常是正文枚举项被误识别如“1 综合监控...”)。
if level == 1 and last_top_level_number and top_level_number < last_top_level_number:
if line and not self._is_noise(line):
content_buffer.append(line)
continue
if level > 6:
continue
@@ -645,10 +678,6 @@ class PDFParser(DocumentParser):
for l in list(section_stack.keys()):
if l > level:
del section_stack[l]
# 若出现层级跳跃如1->3自动回退到父级+1。
if level > 1 and (level - 1) not in section_stack:
section.level = max(section_stack.keys()) if section_stack else 1
current_section = section
else:
@@ -670,7 +699,10 @@ class PDFParser(DocumentParser):
(章节编号, 章节标题) 或 None
"""
# 模式: "3.1 功能需求" / "3.1.2 电场..."
match = re.match(r'^(\d+(?:\.\d+)*)[\s、.)]*(.+)$', line)
if self._is_calendar_line(line):
return None
match = re.match(r'^(\d+(?:\.\d+)*)[\s、.):\-_/]*(.+)$', line)
if not match:
return None
@@ -692,7 +724,7 @@ class PDFParser(DocumentParser):
# 检查子部分是否合理
for part in parts[1:]:
if int(part) > 20:
if int(part) > 30:
return None
# 避免重复
@@ -747,6 +779,10 @@ class PDFParser(DocumentParser):
return (number, title)
def _is_calendar_line(self, text: str) -> bool:
value = (text or "").strip().replace(" ", "")
return bool(re.match(r'^\d{4}\d{1,2}月(?:\d{1,2}日)?$', value))
def _looks_like_statement(self, title: str) -> bool:
"""判断标题是否更像正文语句而非章节名。"""
if not title:

View File

@@ -51,6 +51,8 @@ class SRSTool:
"other": "",
}
UNKNOWN_INTERFACE_VALUES = {"", "未知", "unknown", "n/a", "-", "--", "", "none", "null"}
def __init__(self) -> None:
ToolRegistry.register(self.DEFINITION)
@@ -90,24 +92,78 @@ class SRSTool:
normalized: List[Dict[str, Any]] = []
for index, req in enumerate(extracted, start=1):
description = (req.description or "").strip()
title = description[:40] if description else f"需求项 {index}"
title = self._build_short_title(description, index)
requirement_type = self._normalize_requirement_type(
req_type=getattr(req, "type", "functional"),
interface_name=getattr(req, "interface_name", ""),
interface_type=getattr(req, "interface_type", ""),
data_source=getattr(req, "source", ""),
data_destination=getattr(req, "destination", ""),
)
source_field = f"{req.section_number} {req.section_title}".strip() or "文档解析"
normalized.append(
{
"id": req.id,
"title": title,
"description": description,
"priority": self.PRIORITY_BY_TYPE.get(req.type, ""),
"priority": "",
"acceptance_criteria": [description] if description else ["待补充验收标准"],
"source_field": source_field,
"section_uid": req.section_uid,
"section_number": req.section_number,
"section_title": req.section_title,
"requirement_type": req.type,
"requirement_type": requirement_type,
"interface_name": req.interface_name if requirement_type == "interface" else "",
"interface_type": req.interface_type if requirement_type == "interface" else "",
"data_source": req.source if requirement_type == "interface" else "",
"data_destination": req.destination if requirement_type == "interface" else "",
"sort_order": index,
}
)
return normalized
def _normalize_requirement_type(
self,
req_type: Any,
interface_name: Any,
interface_type: Any,
data_source: Any,
data_destination: Any,
) -> str:
raw_type = str(req_type or "").strip()
mapping = {
"功能需求": "functional",
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
}
normalized_type = mapping.get(raw_type, raw_type)
if normalized_type not in self.PRIORITY_BY_TYPE:
normalized_type = "functional"
fields = [interface_name, interface_type, data_source, data_destination]
has_interface_fields = any(
str(value or "").strip().lower() not in self.UNKNOWN_INTERFACE_VALUES for value in fields
)
if normalized_type == "interface" or has_interface_fields:
return "interface"
return normalized_type
def _build_short_title(self, description: str, index: int) -> str:
text = (description or "").strip()
if not text:
return f"需求项 {index}"
for separator in ("", "", "\n", ";", "."):
if separator in text:
text = text.split(separator, 1)[0].strip()
break
if len(text) <= 20:
return text
return f"{text[:20].rstrip()}"
def _load_config(self) -> Dict[str, Any]:
config_path = Path(__file__).with_name("default_config.yaml")
if config_path.exists():