完善skills;测试用例生成页面功能初步实现
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -15,6 +17,8 @@ from app.services.testing_pipeline import run_testing_pipeline
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
MODEL_PIPELINE_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
||||
@@ -47,38 +51,75 @@ async def generate_testing_content(
|
||||
|
||||
knowledge_context = (payload.knowledge_context or "").strip()
|
||||
if payload.knowledge_base_ids:
|
||||
knowledge_bases = (
|
||||
db.query(KnowledgeBase)
|
||||
.filter(
|
||||
KnowledgeBase.id.in_(payload.knowledge_base_ids),
|
||||
KnowledgeBase.user_id == current_user.id,
|
||||
try:
|
||||
knowledge_bases = (
|
||||
db.query(KnowledgeBase)
|
||||
.filter(
|
||||
KnowledgeBase.id.in_(payload.knowledge_base_ids),
|
||||
KnowledgeBase.user_id == current_user.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
.all()
|
||||
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
|
||||
if kb_vector_stores:
|
||||
retriever = MultiKBRetriever(
|
||||
reranker_weight=settings.RERANKER_WEIGHT,
|
||||
)
|
||||
retrieval_rows = await retriever.retrieve(
|
||||
query=payload.requirement_text,
|
||||
kb_vector_stores=kb_vector_stores,
|
||||
fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
|
||||
top_k=payload.retrieval_top_k,
|
||||
)
|
||||
if retrieval_rows:
|
||||
knowledge_context = format_retrieval_context(retrieval_rows)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Testing generation retrieval fallback triggered for user=%s knowledge_base_ids=%s: %s",
|
||||
current_user.id,
|
||||
payload.knowledge_base_ids,
|
||||
exc,
|
||||
)
|
||||
|
||||
pipeline_kwargs = {
|
||||
"user_requirement_text": payload.requirement_text,
|
||||
"requirement_type_input": payload.requirement_type,
|
||||
"debug": payload.debug,
|
||||
"knowledge_context": knowledge_context,
|
||||
"use_model_generation": payload.use_model_generation,
|
||||
"max_items_per_group": payload.max_items_per_group,
|
||||
"cases_per_item": payload.cases_per_item,
|
||||
"max_focus_points": payload.max_focus_points,
|
||||
"max_llm_calls": payload.max_llm_calls,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.to_thread(run_testing_pipeline, **pipeline_kwargs),
|
||||
timeout=MODEL_PIPELINE_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.exception(
|
||||
"Testing pipeline timed out for user=%s use_model_generation=%s after %s seconds",
|
||||
current_user.id,
|
||||
payload.use_model_generation,
|
||||
MODEL_PIPELINE_TIMEOUT_SECONDS,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail=f"LLM generation timed out after {MODEL_PIPELINE_TIMEOUT_SECONDS} seconds",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Testing pipeline failed for user=%s use_model_generation=%s: %s",
|
||||
current_user.id,
|
||||
payload.use_model_generation,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"LLM generation failed: {exc}",
|
||||
) from exc
|
||||
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
|
||||
if kb_vector_stores:
|
||||
retriever = MultiKBRetriever(
|
||||
reranker_weight=settings.RERANKER_WEIGHT,
|
||||
)
|
||||
retrieval_rows = await retriever.retrieve(
|
||||
query=payload.requirement_text,
|
||||
kb_vector_stores=kb_vector_stores,
|
||||
fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
|
||||
top_k=payload.retrieval_top_k,
|
||||
)
|
||||
if retrieval_rows:
|
||||
knowledge_context = format_retrieval_context(retrieval_rows)
|
||||
|
||||
result = run_testing_pipeline(
|
||||
user_requirement_text=payload.requirement_text,
|
||||
requirement_type_input=payload.requirement_type,
|
||||
debug=payload.debug,
|
||||
knowledge_context=knowledge_context,
|
||||
use_model_generation=payload.use_model_generation,
|
||||
max_items_per_group=payload.max_items_per_group,
|
||||
cases_per_item=payload.cases_per_item,
|
||||
max_focus_points=payload.max_focus_points,
|
||||
max_llm_calls=payload.max_llm_calls,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1,26 +1,46 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import shutil
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models.tooling import SRSExtraction, ToolJob
|
||||
from app.models.knowledge import KnowledgeBase
|
||||
from app.models.tooling import SRSExtraction, TestingGeneration, ToolJob
|
||||
from app.models.user import User
|
||||
from app.schemas.tooling import (
|
||||
SRSToolCreateJobResponse,
|
||||
SRSToolHistoryItem,
|
||||
SRSToolJobStatusResponse,
|
||||
SRSToolRequirementsSaveRequest,
|
||||
SRSToolResultResponse,
|
||||
TestingGenerationCreateRequest,
|
||||
TestingGenerationCreateResponse,
|
||||
TestingGenerationHistoryItem,
|
||||
TestingGenerationJobStatusResponse,
|
||||
TestingGenerationResultResponse,
|
||||
TestingGenerationSaveRequest,
|
||||
ToolDefinitionResponse,
|
||||
)
|
||||
from app.services.srs_job_service import (
|
||||
build_srs_upload_path,
|
||||
build_result_response,
|
||||
delete_srs_job,
|
||||
ensure_upload_path,
|
||||
list_srs_history,
|
||||
replace_requirements,
|
||||
run_srs_job,
|
||||
)
|
||||
from app.services.testing_generation_service import (
|
||||
build_testing_generation_response,
|
||||
create_testing_generation,
|
||||
delete_testing_generation,
|
||||
list_testing_history,
|
||||
)
|
||||
from app.services.testing_generation_job_service import run_testing_generation_job
|
||||
from app.tools.registry import ToolRegistry
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
@@ -173,3 +193,223 @@ async def save_srs_requirements(
|
||||
db.refresh(extraction)
|
||||
|
||||
return build_result_response(job, extraction)
|
||||
|
||||
|
||||
@router.get("/srs/history", response_model=List[SRSToolHistoryItem])
|
||||
async def get_srs_history(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
return list_srs_history(db, current_user.id)
|
||||
|
||||
|
||||
@router.delete("/srs/jobs/{job_id}")
|
||||
async def delete_srs_job_api(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = (
|
||||
db.query(ToolJob)
|
||||
.filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
upload_path = build_srs_upload_path(job_id)
|
||||
delete_srs_job(db, job)
|
||||
|
||||
if upload_path.exists():
|
||||
shutil.rmtree(upload_path, ignore_errors=True)
|
||||
|
||||
return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.post("/testing/generations", response_model=TestingGenerationResultResponse)
|
||||
async def save_testing_generation(
|
||||
payload: TestingGenerationSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
if payload.source_job_id is not None:
|
||||
source_job = (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == payload.source_job_id,
|
||||
ToolJob.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not source_job:
|
||||
raise HTTPException(status_code=404, detail="来源文件不存在")
|
||||
|
||||
if payload.knowledge_base_id is not None:
|
||||
knowledge_base = (
|
||||
db.query(KnowledgeBase)
|
||||
.filter(
|
||||
KnowledgeBase.id == payload.knowledge_base_id,
|
||||
KnowledgeBase.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not knowledge_base:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
|
||||
return create_testing_generation(db, current_user.id, payload)
|
||||
|
||||
|
||||
@router.post("/testing/jobs", response_model=TestingGenerationCreateResponse)
|
||||
async def create_testing_generation_job(
|
||||
background_tasks: BackgroundTasks,
|
||||
payload: TestingGenerationCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
if payload.source_job_id is not None:
|
||||
source_job = (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == payload.source_job_id,
|
||||
ToolJob.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not source_job:
|
||||
raise HTTPException(status_code=404, detail="来源文件不存在")
|
||||
|
||||
if payload.knowledge_base_id is not None:
|
||||
knowledge_base = (
|
||||
db.query(KnowledgeBase)
|
||||
.filter(
|
||||
KnowledgeBase.id == payload.knowledge_base_id,
|
||||
KnowledgeBase.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not knowledge_base:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
|
||||
job = ToolJob(
|
||||
user_id=current_user.id,
|
||||
tool_name="testing.case_generator",
|
||||
status="pending",
|
||||
input_file_name=payload.source_document_name,
|
||||
input_file_path="",
|
||||
output_summary={
|
||||
"source_document_name": payload.source_document_name,
|
||||
"current_step": 0,
|
||||
"total_steps": len(payload.requirements),
|
||||
},
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
background_tasks.add_task(
|
||||
run_testing_generation_job,
|
||||
job.id,
|
||||
payload.dict(),
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/testing/jobs/{job_id}", response_model=TestingGenerationJobStatusResponse)
|
||||
async def get_testing_generation_job_status(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == job_id,
|
||||
ToolJob.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
summary = job.output_summary or {}
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"tool_name": job.tool_name,
|
||||
"status": job.status,
|
||||
"error_message": job.error_message,
|
||||
"started_at": job.started_at,
|
||||
"completed_at": job.completed_at,
|
||||
"source_document_name": summary.get("source_document_name"),
|
||||
"current_step": summary.get("current_step"),
|
||||
"total_steps": summary.get("total_steps"),
|
||||
"current_requirement_id": summary.get("current_requirement_id"),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/testing/history", response_model=List[TestingGenerationHistoryItem])
|
||||
async def get_testing_history(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
return list_testing_history(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/testing/jobs/{job_id}/result", response_model=TestingGenerationResultResponse)
|
||||
async def get_testing_generation_result(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == job_id,
|
||||
ToolJob.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
generation = (
|
||||
db.query(TestingGeneration)
|
||||
.filter(TestingGeneration.job_id == job.id)
|
||||
.first()
|
||||
)
|
||||
if not generation:
|
||||
raise HTTPException(status_code=404, detail="任务结果不存在")
|
||||
|
||||
return build_testing_generation_response(job, generation)
|
||||
|
||||
|
||||
@router.delete("/testing/jobs/{job_id}")
|
||||
async def delete_testing_generation_api(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = (
|
||||
db.query(ToolJob)
|
||||
.filter(
|
||||
ToolJob.id == job_id,
|
||||
ToolJob.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
generation = (
|
||||
db.query(TestingGeneration)
|
||||
.filter(TestingGeneration.job_id == job.id)
|
||||
.first()
|
||||
)
|
||||
if not generation:
|
||||
raise HTTPException(status_code=404, detail="任务结果不存在")
|
||||
|
||||
delete_testing_generation(db, job)
|
||||
return {"message": "删除成功"}
|
||||
|
||||
@@ -2,7 +2,7 @@ from .user import User
|
||||
from .knowledge import KnowledgeBase, Document, DocumentChunk
|
||||
from .chat import Chat, Message
|
||||
from .api_key import APIKey
|
||||
from .tooling import ToolJob, SRSExtraction, SRSRequirement
|
||||
from .tooling import ToolJob, SRSExtraction, SRSRequirement, TestingGeneration
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -15,4 +15,5 @@ __all__ = [
|
||||
"ToolJob",
|
||||
"SRSExtraction",
|
||||
"SRSRequirement",
|
||||
"TestingGeneration",
|
||||
]
|
||||
|
||||
@@ -29,6 +29,13 @@ class ToolJob(Base, TimestampMixin):
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
testing_generation = relationship(
|
||||
"TestingGeneration",
|
||||
back_populates="job",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="TestingGeneration.job_id",
|
||||
)
|
||||
|
||||
|
||||
class SRSExtraction(Base, TimestampMixin):
|
||||
@@ -63,9 +70,14 @@ class SRSRequirement(Base, TimestampMixin):
|
||||
priority = Column(String(16), nullable=False, default="中")
|
||||
acceptance_criteria = Column(JSON, nullable=False)
|
||||
source_field = Column(String(255), nullable=False)
|
||||
section_uid = Column(String(64), nullable=True)
|
||||
section_number = Column(String(64), nullable=True)
|
||||
section_title = Column(String(255), nullable=True)
|
||||
requirement_type = Column(String(64), nullable=True)
|
||||
interface_name = Column(String(255), nullable=True)
|
||||
interface_type = Column(String(128), nullable=True)
|
||||
data_source = Column(String(255), nullable=True)
|
||||
data_destination = Column(String(255), nullable=True)
|
||||
sort_order = Column(Integer, nullable=False, default=0)
|
||||
|
||||
extraction = relationship("SRSExtraction", back_populates="requirements")
|
||||
@@ -74,3 +86,19 @@ class SRSRequirement(Base, TimestampMixin):
|
||||
sa.UniqueConstraint("extraction_id", "requirement_uid", name="uq_srs_extraction_requirement_uid"),
|
||||
sa.Index("idx_srs_requirements_extraction_sort", "extraction_id", "sort_order"),
|
||||
)
|
||||
|
||||
|
||||
class TestingGeneration(Base, TimestampMixin):
|
||||
__tablename__ = "testing_generations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
job_id = Column(Integer, ForeignKey("tool_jobs.id", ondelete="CASCADE"), nullable=False, unique=True)
|
||||
source_job_id = Column(Integer, ForeignKey("tool_jobs.id"), nullable=True, index=True)
|
||||
source_document_name = Column(String(255), nullable=False)
|
||||
generated_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
total_requirements = Column(Integer, nullable=False, default=0)
|
||||
knowledge_base_id = Column(Integer, ForeignKey("knowledge_bases.id"), nullable=True, index=True)
|
||||
generated_file = Column(JSON, nullable=False)
|
||||
|
||||
job = relationship("ToolJob", back_populates="testing_generation", foreign_keys=[job_id])
|
||||
|
||||
|
||||
@@ -29,14 +29,18 @@ class SRSToolJobStatusResponse(BaseModel):
|
||||
|
||||
class SRSToolRequirementItem(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
description: str
|
||||
priority: str
|
||||
acceptanceCriteria: List[str]
|
||||
sourceField: str
|
||||
sectionUid: Optional[str] = None
|
||||
sectionNumber: Optional[str] = None
|
||||
sectionTitle: Optional[str] = None
|
||||
requirementType: Optional[str] = None
|
||||
interfaceName: Optional[str] = None
|
||||
interfaceType: Optional[str] = None
|
||||
dataSource: Optional[str] = None
|
||||
dataDestination: Optional[str] = None
|
||||
sortOrder: int
|
||||
|
||||
|
||||
@@ -46,7 +50,72 @@ class SRSToolResultResponse(BaseModel):
|
||||
generatedAt: str
|
||||
statistics: Dict[str, Any]
|
||||
requirements: List[SRSToolRequirementItem]
|
||||
rawOutput: Dict[str, Any]
|
||||
|
||||
|
||||
class SRSToolHistoryItem(BaseModel):
|
||||
jobId: int
|
||||
documentName: str
|
||||
generatedAt: str
|
||||
totalRequirements: int
|
||||
status: str
|
||||
createdAt: str
|
||||
updatedAt: str
|
||||
|
||||
|
||||
class SRSToolRequirementsSaveRequest(BaseModel):
|
||||
requirements: List[SRSToolRequirementItem]
|
||||
|
||||
|
||||
class TestingGenerationSaveRequest(BaseModel):
|
||||
source_job_id: Optional[int] = None
|
||||
source_document_name: str
|
||||
knowledge_base_id: Optional[int] = None
|
||||
generated_file: Dict[str, Any]
|
||||
|
||||
|
||||
class TestingGenerationCreateRequest(BaseModel):
|
||||
source_job_id: Optional[int] = None
|
||||
source_document_name: str
|
||||
knowledge_base_id: Optional[int] = None
|
||||
requirements: List[SRSToolRequirementItem]
|
||||
|
||||
|
||||
class TestingGenerationCreateResponse(BaseModel):
|
||||
job_id: int
|
||||
status: str
|
||||
|
||||
|
||||
class TestingGenerationJobStatusResponse(BaseModel):
|
||||
job_id: int
|
||||
tool_name: str
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
source_document_name: Optional[str] = None
|
||||
current_step: Optional[int] = None
|
||||
total_steps: Optional[int] = None
|
||||
current_requirement_id: Optional[str] = None
|
||||
|
||||
|
||||
class TestingGenerationResultResponse(BaseModel):
|
||||
jobId: int
|
||||
sourceJobId: Optional[int] = None
|
||||
sourceDocumentName: str
|
||||
generatedAt: str
|
||||
totalRequirements: int
|
||||
knowledgeBaseId: Optional[int] = None
|
||||
generatedFile: Dict[str, Any]
|
||||
|
||||
|
||||
class TestingGenerationHistoryItem(BaseModel):
|
||||
jobId: int
|
||||
sourceJobId: Optional[int] = None
|
||||
sourceDocumentName: str
|
||||
generatedAt: str
|
||||
totalRequirements: int
|
||||
knowledgeBaseId: Optional[int] = None
|
||||
status: str
|
||||
createdAt: str
|
||||
updatedAt: str
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -10,6 +11,45 @@ from app.db.session import SessionLocal
|
||||
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
TYPE_TO_CHINESE = {
|
||||
"functional": "功能需求",
|
||||
"interface": "接口需求",
|
||||
"performance": "性能需求",
|
||||
"security": "安全需求",
|
||||
"reliability": "可靠性需求",
|
||||
"other": "其他需求",
|
||||
}
|
||||
|
||||
|
||||
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
|
||||
text = str(description or "").strip()
|
||||
if not text:
|
||||
return fallback or f"需求项 {index + 1}"
|
||||
|
||||
for separator in ("。", ";", "\n", ";", "."):
|
||||
if separator in text:
|
||||
text = text.split(separator, 1)[0].strip()
|
||||
break
|
||||
|
||||
text = text[:20].strip()
|
||||
return text or fallback or f"需求项 {index + 1}"
|
||||
|
||||
|
||||
def _normalize_requirement_type(value: Any) -> str:
|
||||
text = str(value or "").strip()
|
||||
if text in {"functional", "interface", "performance", "security", "reliability", "other"}:
|
||||
return text
|
||||
|
||||
chinese_map = {
|
||||
"接口需求": "interface",
|
||||
"性能需求": "performance",
|
||||
"安全需求": "security",
|
||||
"可靠性需求": "reliability",
|
||||
"其他需求": "other",
|
||||
"功能需求": "functional",
|
||||
}
|
||||
return chinese_map.get(text, "functional")
|
||||
|
||||
|
||||
def run_srs_job(job_id: int) -> None:
|
||||
db = SessionLocal()
|
||||
@@ -41,14 +81,19 @@ def run_srs_job(job_id: int) -> None:
|
||||
requirement = SRSRequirement(
|
||||
extraction_id=extraction.id,
|
||||
requirement_uid=item["id"],
|
||||
title=item.get("title") or item["id"],
|
||||
title=_build_internal_title(item.get("description"), item["id"]),
|
||||
description=item.get("description") or "",
|
||||
priority=item.get("priority") or "中",
|
||||
acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
|
||||
source_field=item.get("source_field") or "文档解析",
|
||||
section_uid=item.get("section_uid"),
|
||||
section_number=item.get("section_number"),
|
||||
section_title=item.get("section_title"),
|
||||
requirement_type=item.get("requirement_type"),
|
||||
interface_name=item.get("interface_name"),
|
||||
interface_type=item.get("interface_type"),
|
||||
data_source=item.get("data_source"),
|
||||
data_destination=item.get("data_destination"),
|
||||
sort_order=int(item.get("sort_order") or 0),
|
||||
)
|
||||
db.add(requirement)
|
||||
@@ -97,22 +142,8 @@ def ensure_upload_path(job_id: int, file_name: str) -> Path:
|
||||
|
||||
|
||||
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
|
||||
requirements: List[Dict[str, Any]] = []
|
||||
for item in extraction.requirements:
|
||||
requirements.append(
|
||||
{
|
||||
"id": item.requirement_uid,
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"priority": item.priority,
|
||||
"acceptanceCriteria": item.acceptance_criteria or [],
|
||||
"sourceField": item.source_field,
|
||||
"sectionNumber": item.section_number,
|
||||
"sectionTitle": item.section_title,
|
||||
"requirementType": item.requirement_type,
|
||||
"sortOrder": item.sort_order,
|
||||
}
|
||||
)
|
||||
requirements = [_requirement_model_to_payload(item) for item in extraction.requirements]
|
||||
raw_output = _merge_updates_into_raw_output(extraction.raw_output, requirements, extraction.document_name)
|
||||
|
||||
return {
|
||||
"jobId": job.id,
|
||||
@@ -120,10 +151,12 @@ def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str,
|
||||
"generatedAt": extraction.generated_at.isoformat(),
|
||||
"statistics": extraction.statistics or {},
|
||||
"requirements": requirements,
|
||||
"rawOutput": raw_output,
|
||||
}
|
||||
|
||||
|
||||
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
|
||||
normalized_updates = [_normalize_update_payload(item, index) for index, item in enumerate(updates)]
|
||||
existing = {
|
||||
req.requirement_uid: req
|
||||
for req in db.query(SRSRequirement)
|
||||
@@ -132,51 +165,95 @@ def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[D
|
||||
}
|
||||
seen_ids = set()
|
||||
|
||||
for index, item in enumerate(updates):
|
||||
for index, item in enumerate(normalized_updates):
|
||||
uid = item["id"]
|
||||
seen_ids.add(uid)
|
||||
req = existing.get(uid)
|
||||
if req is None:
|
||||
description = item.get("description") if item.get("description") is not None else ""
|
||||
req = SRSRequirement(
|
||||
extraction_id=extraction.id,
|
||||
requirement_uid=uid,
|
||||
title=item.get("title") or uid,
|
||||
description=item.get("description") if item.get("description") is not None else "",
|
||||
title=_build_internal_title(description, uid, int(item.get("sortOrder") or index)),
|
||||
description=description,
|
||||
priority=item.get("priority") or "中",
|
||||
acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
|
||||
source_field=item.get("sourceField") or "文档解析",
|
||||
section_uid=item.get("sectionUid"),
|
||||
section_number=item.get("sectionNumber"),
|
||||
section_title=item.get("sectionTitle"),
|
||||
requirement_type=item.get("requirementType"),
|
||||
sort_order=int(item.get("sortOrder") or index),
|
||||
interface_name=item.get("interfaceName"),
|
||||
interface_type=item.get("interfaceType"),
|
||||
data_source=item.get("dataSource"),
|
||||
data_destination=item.get("dataDestination"),
|
||||
sort_order=int(item.get("sortOrder") or 0),
|
||||
)
|
||||
db.add(req)
|
||||
continue
|
||||
|
||||
req.title = item.get("title", req.title)
|
||||
req.description = item.get("description", req.description)
|
||||
req.title = _build_internal_title(req.description, req.requirement_uid, int(item.get("sortOrder") or index))
|
||||
req.priority = item.get("priority", req.priority)
|
||||
req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
|
||||
req.source_field = item.get("sourceField", req.source_field)
|
||||
req.section_uid = item.get("sectionUid", req.section_uid)
|
||||
req.section_number = item.get("sectionNumber", req.section_number)
|
||||
req.section_title = item.get("sectionTitle", req.section_title)
|
||||
req.requirement_type = item.get("requirementType", req.requirement_type)
|
||||
req.sort_order = int(item.get("sortOrder", index))
|
||||
req.interface_name = item.get("interfaceName", req.interface_name)
|
||||
req.interface_type = item.get("interfaceType", req.interface_type)
|
||||
req.data_source = item.get("dataSource", req.data_source)
|
||||
req.data_destination = item.get("dataDestination", req.data_destination)
|
||||
req.sort_order = int(item.get("sortOrder", req.sort_order))
|
||||
|
||||
for uid, req in existing.items():
|
||||
if uid not in seen_ids:
|
||||
db.delete(req)
|
||||
|
||||
extraction.total_requirements = len(updates)
|
||||
extraction.total_requirements = len(normalized_updates)
|
||||
extraction.statistics = {
|
||||
"total": len(updates),
|
||||
"by_type": _count_requirement_types(updates),
|
||||
}
|
||||
extraction.raw_output = {
|
||||
"document_name": extraction.document_name,
|
||||
"generated_at": extraction.generated_at.isoformat(),
|
||||
"requirements": updates,
|
||||
"total": len(normalized_updates),
|
||||
"by_type": _count_requirement_types(normalized_updates),
|
||||
}
|
||||
extraction.raw_output = _merge_updates_into_raw_output(
|
||||
extraction.raw_output,
|
||||
normalized_updates,
|
||||
extraction.document_name,
|
||||
)
|
||||
|
||||
|
||||
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
|
||||
records: List[Tuple[ToolJob, SRSExtraction]] = (
|
||||
db.query(ToolJob, SRSExtraction)
|
||||
.join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
|
||||
.filter(ToolJob.user_id == user_id)
|
||||
.order_by(ToolJob.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
items: List[Dict[str, Any]] = []
|
||||
for job, extraction in records:
|
||||
items.append(
|
||||
{
|
||||
"jobId": job.id,
|
||||
"documentName": extraction.document_name,
|
||||
"generatedAt": extraction.generated_at.isoformat(),
|
||||
"totalRequirements": extraction.total_requirements,
|
||||
"status": job.status,
|
||||
"createdAt": job.created_at.isoformat(),
|
||||
"updatedAt": job.updated_at.isoformat(),
|
||||
}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
def delete_srs_job(db: Session, job: ToolJob) -> None:
|
||||
db.delete(job)
|
||||
db.commit()
|
||||
|
||||
|
||||
def build_srs_upload_path(job_id: int) -> Path:
|
||||
return Path("uploads") / "srs_jobs" / str(job_id)
|
||||
|
||||
|
||||
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
@@ -185,3 +262,210 @@ def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
req_type = item.get("requirementType") or "functional"
|
||||
stats[req_type] = stats.get(req_type, 0) + 1
|
||||
return stats
|
||||
|
||||
|
||||
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": item.requirement_uid,
|
||||
"description": item.description,
|
||||
"priority": item.priority,
|
||||
"acceptanceCriteria": item.acceptance_criteria or [],
|
||||
"sourceField": item.source_field,
|
||||
"sectionUid": item.section_uid,
|
||||
"sectionNumber": item.section_number,
|
||||
"sectionTitle": item.section_title,
|
||||
"requirementType": item.requirement_type,
|
||||
"interfaceName": item.interface_name,
|
||||
"interfaceType": item.interface_type,
|
||||
"dataSource": item.data_source,
|
||||
"dataDestination": item.data_destination,
|
||||
"sortOrder": item.sort_order,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
|
||||
requirement_type = _normalize_requirement_type(item.get("requirementType"))
|
||||
normalized = {
|
||||
"id": str(item.get("id") or f"REQ-{index + 1:03d}"),
|
||||
"description": str(item.get("description") or "").strip(),
|
||||
"priority": item.get("priority") or "中",
|
||||
"acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
|
||||
"sourceField": item.get("sourceField") or "文档解析",
|
||||
"sectionUid": item.get("sectionUid"),
|
||||
"sectionNumber": item.get("sectionNumber"),
|
||||
"sectionTitle": item.get("sectionTitle"),
|
||||
"requirementType": requirement_type,
|
||||
"interfaceName": item.get("interfaceName") or "",
|
||||
"interfaceType": item.get("interfaceType") or "",
|
||||
"dataSource": item.get("dataSource") or "",
|
||||
"dataDestination": item.get("dataDestination") or "",
|
||||
"sortOrder": int(item.get("sortOrder") or index),
|
||||
}
|
||||
|
||||
if requirement_type != "interface":
|
||||
normalized["interfaceName"] = ""
|
||||
normalized["interfaceType"] = ""
|
||||
normalized["dataSource"] = ""
|
||||
normalized["dataDestination"] = ""
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _merge_updates_into_raw_output(
|
||||
raw_output: Dict[str, Any] | None,
|
||||
updates: List[Dict[str, Any]],
|
||||
document_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
if not isinstance(raw_output, dict) or "需求内容" not in raw_output:
|
||||
return _build_raw_output_from_flat(updates, document_name)
|
||||
|
||||
result = deepcopy(raw_output)
|
||||
content = result.get("需求内容")
|
||||
if not isinstance(content, dict):
|
||||
return _build_raw_output_from_flat(updates, document_name)
|
||||
|
||||
updates_by_section = _group_updates_by_section(updates)
|
||||
_rewrite_content_requirements(content, updates_by_section)
|
||||
_append_unmatched_sections(content, updates_by_section)
|
||||
_refresh_metadata(result, updates)
|
||||
return result
|
||||
|
||||
|
||||
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
|
||||
grouped: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for item in updates:
|
||||
key = _section_key(item.get("sectionUid"), item.get("sectionNumber"), item.get("sectionTitle"))
|
||||
grouped.setdefault(key, []).append(item)
|
||||
|
||||
for values in grouped.values():
|
||||
values.sort(key=lambda value: int(value.get("sortOrder") or 0))
|
||||
return grouped
|
||||
|
||||
|
||||
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
|
||||
for section in content.values():
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
|
||||
section_info = section.get("章节信息") or {}
|
||||
section_key = _section_key(
|
||||
section_info.get("章节UID"),
|
||||
section_info.get("章节编号"),
|
||||
section_info.get("章节标题"),
|
||||
)
|
||||
if section_key in updates_by_section:
|
||||
section["需求列表"] = [_to_raw_requirement_item(item) for item in updates_by_section.pop(section_key)]
|
||||
elif "需求列表" in section:
|
||||
section["需求列表"] = []
|
||||
|
||||
children = section.get("子章节")
|
||||
if isinstance(children, dict):
|
||||
_rewrite_content_requirements(children, updates_by_section)
|
||||
|
||||
|
||||
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
|
||||
if not updates_by_section:
|
||||
return
|
||||
|
||||
orphan_key = "未归类章节"
|
||||
orphan_section = content.get(orphan_key)
|
||||
if not isinstance(orphan_section, dict):
|
||||
orphan_section = {
|
||||
"章节信息": {
|
||||
"章节编号": "",
|
||||
"章节标题": orphan_key,
|
||||
"章节级别": 1,
|
||||
},
|
||||
"需求列表": [],
|
||||
}
|
||||
content[orphan_key] = orphan_section
|
||||
|
||||
all_reqs: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
|
||||
for values in updates_by_section.values():
|
||||
for item in values:
|
||||
all_reqs.append(_to_raw_requirement_item(item))
|
||||
orphan_section["需求列表"] = all_reqs
|
||||
updates_by_section.clear()
|
||||
|
||||
|
||||
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
|
||||
metadata = raw_output.get("文档元数据")
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
raw_output["文档元数据"] = metadata
|
||||
|
||||
metadata["总需求数"] = len(updates)
|
||||
metadata["生成时间"] = datetime.now().isoformat()
|
||||
|
||||
type_stats: Dict[str, int] = {}
|
||||
for item in updates:
|
||||
req_type = item.get("requirementType") or "functional"
|
||||
cn_type = TYPE_TO_CHINESE.get(req_type, "其他需求")
|
||||
type_stats[cn_type] = type_stats.get(cn_type, 0) + 1
|
||||
metadata["需求类型统计"] = type_stats
|
||||
|
||||
|
||||
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
|
||||
req_type = item.get("requirementType") or "functional"
|
||||
raw_item = {
|
||||
"需求类型": TYPE_TO_CHINESE.get(req_type, "其他需求"),
|
||||
"需求编号": item.get("id") or "",
|
||||
"需求描述": item.get("description") or "",
|
||||
"优先级": item.get("priority") or "中",
|
||||
}
|
||||
if req_type == "interface":
|
||||
raw_item["接口名称"] = item.get("interfaceName") or ""
|
||||
raw_item["接口类型"] = item.get("interfaceType") or ""
|
||||
raw_item["数据来源"] = item.get("dataSource") or ""
|
||||
raw_item["数据目的地"] = item.get("dataDestination") or ""
|
||||
return raw_item
|
||||
|
||||
|
||||
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
|
||||
grouped = _group_updates_by_section(updates)
|
||||
content: Dict[str, Any] = {}
|
||||
for key, values in grouped.items():
|
||||
number, title = _parse_section_key(key)
|
||||
section_title = title or "未归类章节"
|
||||
display_key = f"{number} {section_title}".strip()
|
||||
content[display_key] = {
|
||||
"章节信息": {
|
||||
"章节编号": number,
|
||||
"章节标题": section_title,
|
||||
"章节级别": max(len(number.split(".")), 1) if number else 1,
|
||||
},
|
||||
"需求列表": [_to_raw_requirement_item(item) for item in values],
|
||||
}
|
||||
|
||||
raw_output = {
|
||||
"文档元数据": {
|
||||
"标题": document_name,
|
||||
"生成时间": datetime.now().isoformat(),
|
||||
"总需求数": len(updates),
|
||||
"需求类型统计": {},
|
||||
},
|
||||
"需求内容": content,
|
||||
}
|
||||
_refresh_metadata(raw_output, updates)
|
||||
return raw_output
|
||||
|
||||
|
||||
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
|
||||
uid = str(section_uid or "").strip()
|
||||
if uid:
|
||||
return f"uid::{uid}"
|
||||
number = str(section_number or "").strip()
|
||||
title = str(section_title or "").strip()
|
||||
return f"number::{number}::title::{title}"
|
||||
|
||||
|
||||
def _parse_section_key(value: str) -> Tuple[str, str]:
|
||||
if value.startswith("uid::"):
|
||||
return "", "未归类章节"
|
||||
number = ""
|
||||
title = ""
|
||||
parts = value.split("::")
|
||||
if len(parts) >= 4:
|
||||
number = parts[1]
|
||||
title = parts[3]
|
||||
return number, title
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.models.tooling import TestingGeneration, ToolJob
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline import run_testing_pipeline
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
items: List[Dict[str, Any]] = []
|
||||
for current in value.values():
|
||||
items.extend(current)
|
||||
return items
|
||||
|
||||
|
||||
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
kb_vector_stores: List[Dict[str, Any]] = []
|
||||
|
||||
for kb in knowledge_bases:
|
||||
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
|
||||
if not documents:
|
||||
continue
|
||||
|
||||
store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
collection_name=f"kb_{kb.id}",
|
||||
embedding_function=embeddings,
|
||||
)
|
||||
kb_vector_stores.append({"kb_id": kb.id, "store": store})
|
||||
|
||||
return kb_vector_stores
|
||||
|
||||
|
||||
def _resolve_knowledge_context(
|
||||
db: Session,
|
||||
*,
|
||||
user_id: int,
|
||||
requirement_text: str,
|
||||
knowledge_base_id: int | None,
|
||||
) -> str:
|
||||
if knowledge_base_id is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
knowledge_bases = (
|
||||
db.query(KnowledgeBase)
|
||||
.filter(
|
||||
KnowledgeBase.id == knowledge_base_id,
|
||||
KnowledgeBase.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
|
||||
if not kb_vector_stores:
|
||||
return ""
|
||||
|
||||
retriever = MultiKBRetriever(
|
||||
reranker_weight=settings.RERANKER_WEIGHT,
|
||||
)
|
||||
rows = asyncio.run(
|
||||
retriever.retrieve(
|
||||
query=requirement_text,
|
||||
kb_vector_stores=kb_vector_stores,
|
||||
fetch_k_per_kb=16,
|
||||
top_k=8,
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
return format_retrieval_context(rows)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
test_items = [
|
||||
{
|
||||
"id": item.get("id"),
|
||||
"content": item.get("content"),
|
||||
}
|
||||
for item in _flatten_record(pipeline_result.get("test_items", {}))
|
||||
]
|
||||
test_cases = [
|
||||
{
|
||||
"id": item.get("id"),
|
||||
"itemId": item.get("item_id"),
|
||||
"testContent": item.get("test_content"),
|
||||
"operationSteps": item.get("operation_steps", []),
|
||||
"expectedResultPlaceholder": item.get("expected_result_placeholder"),
|
||||
}
|
||||
for item in _flatten_record(pipeline_result.get("test_cases", {}))
|
||||
]
|
||||
expected_results = [
|
||||
{
|
||||
"id": item.get("id"),
|
||||
"caseId": item.get("case_id"),
|
||||
"result": item.get("result"),
|
||||
}
|
||||
for item in _flatten_record(pipeline_result.get("expected_results", {}))
|
||||
]
|
||||
|
||||
return {
|
||||
**req,
|
||||
"测试项": test_items,
|
||||
"测试用例": test_cases,
|
||||
"预期结果": expected_results,
|
||||
}
|
||||
|
||||
|
||||
def _mark_job_failed(job_id: int, error_message: str) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
job.status = "failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.error_message = error_message[:2000]
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
|
||||
requirements = payload.get("requirements") or []
|
||||
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
|
||||
source_job_id = payload.get("source_job_id")
|
||||
knowledge_base_id = payload.get("knowledge_base_id")
|
||||
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.utcnow()
|
||||
job.error_message = None
|
||||
job.output_summary = {
|
||||
"source_document_name": source_document_name,
|
||||
"current_step": 0,
|
||||
"total_steps": len(requirements),
|
||||
}
|
||||
db.commit()
|
||||
|
||||
generated_requirements: List[Dict[str, Any]] = []
|
||||
|
||||
for index, req in enumerate(requirements):
|
||||
req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
|
||||
job.output_summary = {
|
||||
"source_document_name": source_document_name,
|
||||
"current_step": index + 1,
|
||||
"total_steps": len(requirements),
|
||||
"current_requirement_id": req_id,
|
||||
}
|
||||
db.commit()
|
||||
|
||||
description = str(req.get("description") or "").strip()
|
||||
if not description:
|
||||
generated_requirements.append(
|
||||
{
|
||||
**req,
|
||||
"测试项": [],
|
||||
"测试用例": [],
|
||||
"预期结果": [],
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
knowledge_context = _resolve_knowledge_context(
|
||||
db,
|
||||
user_id=job.user_id,
|
||||
requirement_text=description,
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
)
|
||||
|
||||
pipeline_result = run_testing_pipeline(
|
||||
user_requirement_text=description,
|
||||
requirement_type_input=req.get("requirementType"),
|
||||
debug=False,
|
||||
knowledge_context=knowledge_context,
|
||||
use_model_generation=True,
|
||||
max_items_per_group=12,
|
||||
cases_per_item=2,
|
||||
max_focus_points=6,
|
||||
max_llm_calls=10,
|
||||
)
|
||||
generated_requirements.append(_build_generated_requirement(req, pipeline_result))
|
||||
|
||||
generated_at = datetime.utcnow()
|
||||
generated_file = {
|
||||
"sourceDocument": source_document_name,
|
||||
"sourceJobId": source_job_id,
|
||||
"generatedAt": generated_at.isoformat(),
|
||||
"totalRequirements": len(generated_requirements),
|
||||
"knowledgeBaseId": knowledge_base_id,
|
||||
"requirements": generated_requirements,
|
||||
}
|
||||
|
||||
generation = TestingGeneration(
|
||||
job_id=job.id,
|
||||
source_job_id=source_job_id,
|
||||
source_document_name=source_document_name,
|
||||
generated_at=generated_at,
|
||||
total_requirements=len(generated_requirements),
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
generated_file=generated_file,
|
||||
)
|
||||
db.add(generation)
|
||||
|
||||
job.status = "completed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.output_summary = {
|
||||
"source_document_name": source_document_name,
|
||||
"current_step": len(generated_requirements),
|
||||
"total_steps": len(generated_requirements),
|
||||
"total_requirements": len(generated_requirements),
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
}
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
_mark_job_failed(job_id, str(exc))
|
||||
finally:
|
||||
db.close()
|
||||
111
rag-web-ui/backend/app/services/testing_generation_service.py
Normal file
111
rag-web-ui/backend/app/services/testing_generation_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tooling import TestingGeneration, ToolJob
|
||||
from app.schemas.tooling import TestingGenerationSaveRequest
|
||||
|
||||
TESTING_TOOL_NAME = "testing.case_generator"
|
||||
|
||||
|
||||
def _resolve_total_requirements(generated_file: Dict[str, Any]) -> int:
|
||||
requirements = generated_file.get("requirements")
|
||||
if isinstance(requirements, list):
|
||||
return len(requirements)
|
||||
|
||||
total = generated_file.get("totalRequirements")
|
||||
if isinstance(total, int) and total >= 0:
|
||||
return total
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def build_testing_generation_response(job: ToolJob, generation: TestingGeneration) -> Dict[str, Any]:
|
||||
return {
|
||||
"jobId": job.id,
|
||||
"sourceJobId": generation.source_job_id,
|
||||
"sourceDocumentName": generation.source_document_name,
|
||||
"generatedAt": generation.generated_at.isoformat(),
|
||||
"totalRequirements": generation.total_requirements,
|
||||
"knowledgeBaseId": generation.knowledge_base_id,
|
||||
"generatedFile": generation.generated_file or {},
|
||||
}
|
||||
|
||||
|
||||
def create_testing_generation(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
payload: TestingGenerationSaveRequest,
|
||||
) -> Dict[str, Any]:
|
||||
now = datetime.utcnow()
|
||||
total_requirements = _resolve_total_requirements(payload.generated_file)
|
||||
|
||||
job = ToolJob(
|
||||
user_id=user_id,
|
||||
tool_name=TESTING_TOOL_NAME,
|
||||
status="completed",
|
||||
input_file_name=payload.source_document_name,
|
||||
input_file_path="",
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
output_summary={
|
||||
"source_document_name": payload.source_document_name,
|
||||
"total_requirements": total_requirements,
|
||||
"knowledge_base_id": payload.knowledge_base_id,
|
||||
},
|
||||
)
|
||||
db.add(job)
|
||||
db.flush()
|
||||
|
||||
generation = TestingGeneration(
|
||||
job_id=job.id,
|
||||
source_job_id=payload.source_job_id,
|
||||
source_document_name=payload.source_document_name,
|
||||
generated_at=now,
|
||||
total_requirements=total_requirements,
|
||||
knowledge_base_id=payload.knowledge_base_id,
|
||||
generated_file=payload.generated_file,
|
||||
)
|
||||
db.add(generation)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
db.refresh(generation)
|
||||
|
||||
return build_testing_generation_response(job, generation)
|
||||
|
||||
|
||||
def list_testing_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
|
||||
rows: List[Tuple[ToolJob, TestingGeneration]] = (
|
||||
db.query(ToolJob, TestingGeneration)
|
||||
.join(TestingGeneration, TestingGeneration.job_id == ToolJob.id)
|
||||
.filter(
|
||||
ToolJob.user_id == user_id,
|
||||
ToolJob.tool_name == TESTING_TOOL_NAME,
|
||||
)
|
||||
.order_by(ToolJob.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
items: List[Dict[str, Any]] = []
|
||||
for job, generation in rows:
|
||||
items.append(
|
||||
{
|
||||
"jobId": job.id,
|
||||
"sourceJobId": generation.source_job_id,
|
||||
"sourceDocumentName": generation.source_document_name,
|
||||
"generatedAt": generation.generated_at.isoformat(),
|
||||
"totalRequirements": generation.total_requirements,
|
||||
"knowledgeBaseId": generation.knowledge_base_id,
|
||||
"status": job.status,
|
||||
"createdAt": job.created_at.isoformat(),
|
||||
"updatedAt": job.updated_at.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def delete_testing_generation(db: Session, job: ToolJob) -> None:
|
||||
db.delete(job)
|
||||
db.commit()
|
||||
@@ -4,7 +4,6 @@ from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.testing_pipeline.tools import build_default_tool_chain
|
||||
|
||||
|
||||
@@ -42,6 +41,8 @@ def run_testing_pipeline(
|
||||
llm_model = None
|
||||
if use_model_generation:
|
||||
try:
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
|
||||
llm_model = LLMFactory.create(streaming=False)
|
||||
except Exception:
|
||||
llm_model = None
|
||||
|
||||
@@ -137,6 +137,19 @@ class DocumentParser(ABC):
|
||||
chinese_numbers = '一二三四五六七八九十百千万'
|
||||
return text and all(c in chinese_numbers for c in text)
|
||||
|
||||
def _section_sort_key(self, section: 'Section') -> Tuple[int, List[int], str]:
|
||||
number = (section.number or "").strip()
|
||||
if number and re.match(r'^\d+(?:\.\d+)*$', number):
|
||||
return (0, [int(part) for part in number.split('.')], section.title or "")
|
||||
return (1, [section.level], section.title or "")
|
||||
|
||||
def _sort_sections_by_number(self, sections: List['Section']) -> List['Section']:
|
||||
ordered = sorted(sections, key=self._section_sort_key)
|
||||
for section in ordered:
|
||||
if section.children:
|
||||
section.children = self._sort_sections_by_number(section.children)
|
||||
return ordered
|
||||
|
||||
|
||||
class DocxParser(DocumentParser):
|
||||
"""DOCX格式文档解析器"""
|
||||
@@ -210,6 +223,7 @@ class DocxParser(DocumentParser):
|
||||
|
||||
# 为没有编号的章节自动生成编号
|
||||
self._auto_number_sections(self.sections)
|
||||
self.sections = self._sort_sections_by_number(self.sections)
|
||||
|
||||
logger.info(f"完成Docx解析,提取{len(self.sections)}个顶级章节")
|
||||
return self.sections
|
||||
@@ -236,12 +250,17 @@ class DocxParser(DocumentParser):
|
||||
"""解析标题,返回(编号, 标题, 级别)"""
|
||||
style_name = paragraph.style.name if paragraph.style else ""
|
||||
is_heading_style = style_name.lower().startswith('heading') if style_name else False
|
||||
|
||||
if self._is_calendar_line(text):
|
||||
return None
|
||||
|
||||
# 数字编号标题
|
||||
match = re.match(r'^(\d+(?:\.\d+)*)\s*[\.、]?\s*(.+)$', text)
|
||||
match = re.match(r'^(\d+(?:\.\d+)*)\s*[\.、.))::\-_/]?\s*(.+)$', text)
|
||||
if match and self._is_valid_heading(match.group(2)):
|
||||
number = match.group(1)
|
||||
title = match.group(2).strip()
|
||||
if not self._is_valid_numbered_heading(number, title):
|
||||
return None
|
||||
level = len(number.split('.'))
|
||||
return number, title, level
|
||||
|
||||
@@ -263,6 +282,31 @@ class DocxParser(DocumentParser):
|
||||
|
||||
return None
|
||||
|
||||
def _is_calendar_line(self, text: str) -> bool:
|
||||
value = (text or "").strip().replace(" ", "")
|
||||
return bool(re.match(r'^\d{4}年\d{1,2}月(?:\d{1,2}日)?$', value))
|
||||
|
||||
def _is_valid_numbered_heading(self, number: str, title: str) -> bool:
|
||||
parts = number.split('.')
|
||||
if len(parts) > 6:
|
||||
return False
|
||||
|
||||
first = int(parts[0])
|
||||
if first < 1 or first > 30:
|
||||
return False
|
||||
|
||||
for part in parts[1:]:
|
||||
if int(part) > 30:
|
||||
return False
|
||||
|
||||
if len(parts) == 1 and re.match(r'^年\d{1,2}月', title):
|
||||
return False
|
||||
|
||||
if title and title[0].isdigit():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _iter_block_items(self, parent):
|
||||
"""按文档顺序迭代段落和表格"""
|
||||
from docx.text.paragraph import Paragraph
|
||||
@@ -356,6 +400,7 @@ class PDFParser(DocumentParser):
|
||||
|
||||
# 6. 为没有编号的章节自动生成编号
|
||||
self._auto_number_sections(self.sections)
|
||||
self.sections = self._sort_sections_by_number(self.sections)
|
||||
|
||||
logger.info(f"完成PDF解析,提取{len(self.sections)}个顶级章节")
|
||||
return self.sections
|
||||
@@ -599,18 +644,6 @@ class PDFParser(DocumentParser):
|
||||
level = len(number.split('.'))
|
||||
top_level_number = int(number.split('.')[0])
|
||||
|
||||
# 顶级章节序号大幅跳跃通常是误识别(如正文中的“8 表...”)。
|
||||
if level == 1 and last_top_level_number and top_level_number > last_top_level_number + 1:
|
||||
if line and not self._is_noise(line):
|
||||
content_buffer.append(line)
|
||||
continue
|
||||
|
||||
# 顶级章节编号倒退通常是正文枚举项被误识别(如“1 综合监控...”)。
|
||||
if level == 1 and last_top_level_number and top_level_number < last_top_level_number:
|
||||
if line and not self._is_noise(line):
|
||||
content_buffer.append(line)
|
||||
continue
|
||||
|
||||
if level > 6:
|
||||
continue
|
||||
|
||||
@@ -645,10 +678,6 @@ class PDFParser(DocumentParser):
|
||||
for l in list(section_stack.keys()):
|
||||
if l > level:
|
||||
del section_stack[l]
|
||||
|
||||
# 若出现层级跳跃(如1->3),自动回退到父级+1。
|
||||
if level > 1 and (level - 1) not in section_stack:
|
||||
section.level = max(section_stack.keys()) if section_stack else 1
|
||||
|
||||
current_section = section
|
||||
else:
|
||||
@@ -670,7 +699,10 @@ class PDFParser(DocumentParser):
|
||||
(章节编号, 章节标题) 或 None
|
||||
"""
|
||||
# 模式: "3.1 功能需求" / "3.1.2 电场..."
|
||||
match = re.match(r'^(\d+(?:\.\d+)*)[\s、.))]*(.+)$', line)
|
||||
if self._is_calendar_line(line):
|
||||
return None
|
||||
|
||||
match = re.match(r'^(\d+(?:\.\d+)*)[\s、..))::\-_/]*(.+)$', line)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
@@ -692,7 +724,7 @@ class PDFParser(DocumentParser):
|
||||
|
||||
# 检查子部分是否合理
|
||||
for part in parts[1:]:
|
||||
if int(part) > 20:
|
||||
if int(part) > 30:
|
||||
return None
|
||||
|
||||
# 避免重复
|
||||
@@ -747,6 +779,10 @@ class PDFParser(DocumentParser):
|
||||
|
||||
return (number, title)
|
||||
|
||||
def _is_calendar_line(self, text: str) -> bool:
|
||||
value = (text or "").strip().replace(" ", "")
|
||||
return bool(re.match(r'^\d{4}年\d{1,2}月(?:\d{1,2}日)?$', value))
|
||||
|
||||
def _looks_like_statement(self, title: str) -> bool:
|
||||
"""判断标题是否更像正文语句而非章节名。"""
|
||||
if not title:
|
||||
|
||||
@@ -51,6 +51,8 @@ class SRSTool:
|
||||
"other": "低",
|
||||
}
|
||||
|
||||
UNKNOWN_INTERFACE_VALUES = {"", "未知", "unknown", "n/a", "-", "--", "无", "none", "null"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
ToolRegistry.register(self.DEFINITION)
|
||||
|
||||
@@ -90,24 +92,78 @@ class SRSTool:
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for index, req in enumerate(extracted, start=1):
|
||||
description = (req.description or "").strip()
|
||||
title = description[:40] if description else f"需求项 {index}"
|
||||
title = self._build_short_title(description, index)
|
||||
requirement_type = self._normalize_requirement_type(
|
||||
req_type=getattr(req, "type", "functional"),
|
||||
interface_name=getattr(req, "interface_name", ""),
|
||||
interface_type=getattr(req, "interface_type", ""),
|
||||
data_source=getattr(req, "source", ""),
|
||||
data_destination=getattr(req, "destination", ""),
|
||||
)
|
||||
source_field = f"{req.section_number} {req.section_title}".strip() or "文档解析"
|
||||
normalized.append(
|
||||
{
|
||||
"id": req.id,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"priority": self.PRIORITY_BY_TYPE.get(req.type, "中"),
|
||||
"priority": "中",
|
||||
"acceptance_criteria": [description] if description else ["待补充验收标准"],
|
||||
"source_field": source_field,
|
||||
"section_uid": req.section_uid,
|
||||
"section_number": req.section_number,
|
||||
"section_title": req.section_title,
|
||||
"requirement_type": req.type,
|
||||
"requirement_type": requirement_type,
|
||||
"interface_name": req.interface_name if requirement_type == "interface" else "",
|
||||
"interface_type": req.interface_type if requirement_type == "interface" else "",
|
||||
"data_source": req.source if requirement_type == "interface" else "",
|
||||
"data_destination": req.destination if requirement_type == "interface" else "",
|
||||
"sort_order": index,
|
||||
}
|
||||
)
|
||||
return normalized
|
||||
|
||||
def _normalize_requirement_type(
|
||||
self,
|
||||
req_type: Any,
|
||||
interface_name: Any,
|
||||
interface_type: Any,
|
||||
data_source: Any,
|
||||
data_destination: Any,
|
||||
) -> str:
|
||||
raw_type = str(req_type or "").strip()
|
||||
mapping = {
|
||||
"功能需求": "functional",
|
||||
"接口需求": "interface",
|
||||
"性能需求": "performance",
|
||||
"安全需求": "security",
|
||||
"可靠性需求": "reliability",
|
||||
"其他需求": "other",
|
||||
}
|
||||
normalized_type = mapping.get(raw_type, raw_type)
|
||||
if normalized_type not in self.PRIORITY_BY_TYPE:
|
||||
normalized_type = "functional"
|
||||
|
||||
fields = [interface_name, interface_type, data_source, data_destination]
|
||||
has_interface_fields = any(
|
||||
str(value or "").strip().lower() not in self.UNKNOWN_INTERFACE_VALUES for value in fields
|
||||
)
|
||||
|
||||
if normalized_type == "interface" or has_interface_fields:
|
||||
return "interface"
|
||||
return normalized_type
|
||||
|
||||
def _build_short_title(self, description: str, index: int) -> str:
|
||||
text = (description or "").strip()
|
||||
if not text:
|
||||
return f"需求项 {index}"
|
||||
for separator in ("。", ";", "\n", ";", "."):
|
||||
if separator in text:
|
||||
text = text.split(separator, 1)[0].strip()
|
||||
break
|
||||
if len(text) <= 20:
|
||||
return text
|
||||
return f"{text[:20].rstrip()}"
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
config_path = Path(__file__).with_name("default_config.yaml")
|
||||
if config_path.exists():
|
||||
|
||||
Reference in New Issue
Block a user