完善skills;测试用例生成页面功能初步实现

This commit is contained in:
2026-05-05 19:45:33 +08:00
parent 0c2ed67e2a
commit 69b49d28b2
35 changed files with 4396 additions and 658 deletions

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
@@ -10,6 +11,45 @@ from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.tools.srs_reqs_qwen import get_srs_tool
TYPE_TO_CHINESE = {
"functional": "功能需求",
"interface": "接口需求",
"performance": "性能需求",
"security": "安全需求",
"reliability": "可靠性需求",
"other": "其他需求",
}
def _build_internal_title(description: Any, fallback: str, index: int = 0) -> str:
text = str(description or "").strip()
if not text:
return fallback or f"需求项 {index + 1}"
for separator in ("", "", "\n", ";", "."):
if separator in text:
text = text.split(separator, 1)[0].strip()
break
text = text[:20].strip()
return text or fallback or f"需求项 {index + 1}"
def _normalize_requirement_type(value: Any) -> str:
text = str(value or "").strip()
if text in {"functional", "interface", "performance", "security", "reliability", "other"}:
return text
chinese_map = {
"接口需求": "interface",
"性能需求": "performance",
"安全需求": "security",
"可靠性需求": "reliability",
"其他需求": "other",
"功能需求": "functional",
}
return chinese_map.get(text, "functional")
def run_srs_job(job_id: int) -> None:
db = SessionLocal()
@@ -41,14 +81,19 @@ def run_srs_job(job_id: int) -> None:
requirement = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=item["id"],
title=item.get("title") or item["id"],
title=_build_internal_title(item.get("description"), item["id"]),
description=item.get("description") or "",
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptance_criteria") or ["待补充验收标准"],
source_field=item.get("source_field") or "文档解析",
section_uid=item.get("section_uid"),
section_number=item.get("section_number"),
section_title=item.get("section_title"),
requirement_type=item.get("requirement_type"),
interface_name=item.get("interface_name"),
interface_type=item.get("interface_type"),
data_source=item.get("data_source"),
data_destination=item.get("data_destination"),
sort_order=int(item.get("sort_order") or 0),
)
db.add(requirement)
@@ -97,22 +142,8 @@ def ensure_upload_path(job_id: int, file_name: str) -> Path:
def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str, Any]:
requirements: List[Dict[str, Any]] = []
for item in extraction.requirements:
requirements.append(
{
"id": item.requirement_uid,
"title": item.title,
"description": item.description,
"priority": item.priority,
"acceptanceCriteria": item.acceptance_criteria or [],
"sourceField": item.source_field,
"sectionNumber": item.section_number,
"sectionTitle": item.section_title,
"requirementType": item.requirement_type,
"sortOrder": item.sort_order,
}
)
requirements = [_requirement_model_to_payload(item) for item in extraction.requirements]
raw_output = _merge_updates_into_raw_output(extraction.raw_output, requirements, extraction.document_name)
return {
"jobId": job.id,
@@ -120,10 +151,12 @@ def build_result_response(job: ToolJob, extraction: SRSExtraction) -> Dict[str,
"generatedAt": extraction.generated_at.isoformat(),
"statistics": extraction.statistics or {},
"requirements": requirements,
"rawOutput": raw_output,
}
def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[Dict[str, Any]]) -> None:
normalized_updates = [_normalize_update_payload(item, index) for index, item in enumerate(updates)]
existing = {
req.requirement_uid: req
for req in db.query(SRSRequirement)
@@ -132,51 +165,95 @@ def replace_requirements(db: Session, extraction: SRSExtraction, updates: List[D
}
seen_ids = set()
for index, item in enumerate(updates):
for index, item in enumerate(normalized_updates):
uid = item["id"]
seen_ids.add(uid)
req = existing.get(uid)
if req is None:
description = item.get("description") if item.get("description") is not None else ""
req = SRSRequirement(
extraction_id=extraction.id,
requirement_uid=uid,
title=item.get("title") or uid,
description=item.get("description") if item.get("description") is not None else "",
title=_build_internal_title(description, uid, int(item.get("sortOrder") or index)),
description=description,
priority=item.get("priority") or "",
acceptance_criteria=item.get("acceptanceCriteria") or ["待补充验收标准"],
source_field=item.get("sourceField") or "文档解析",
section_uid=item.get("sectionUid"),
section_number=item.get("sectionNumber"),
section_title=item.get("sectionTitle"),
requirement_type=item.get("requirementType"),
sort_order=int(item.get("sortOrder") or index),
interface_name=item.get("interfaceName"),
interface_type=item.get("interfaceType"),
data_source=item.get("dataSource"),
data_destination=item.get("dataDestination"),
sort_order=int(item.get("sortOrder") or 0),
)
db.add(req)
continue
req.title = item.get("title", req.title)
req.description = item.get("description", req.description)
req.title = _build_internal_title(req.description, req.requirement_uid, int(item.get("sortOrder") or index))
req.priority = item.get("priority", req.priority)
req.acceptance_criteria = item.get("acceptanceCriteria", req.acceptance_criteria)
req.source_field = item.get("sourceField", req.source_field)
req.section_uid = item.get("sectionUid", req.section_uid)
req.section_number = item.get("sectionNumber", req.section_number)
req.section_title = item.get("sectionTitle", req.section_title)
req.requirement_type = item.get("requirementType", req.requirement_type)
req.sort_order = int(item.get("sortOrder", index))
req.interface_name = item.get("interfaceName", req.interface_name)
req.interface_type = item.get("interfaceType", req.interface_type)
req.data_source = item.get("dataSource", req.data_source)
req.data_destination = item.get("dataDestination", req.data_destination)
req.sort_order = int(item.get("sortOrder", req.sort_order))
for uid, req in existing.items():
if uid not in seen_ids:
db.delete(req)
extraction.total_requirements = len(updates)
extraction.total_requirements = len(normalized_updates)
extraction.statistics = {
"total": len(updates),
"by_type": _count_requirement_types(updates),
}
extraction.raw_output = {
"document_name": extraction.document_name,
"generated_at": extraction.generated_at.isoformat(),
"requirements": updates,
"total": len(normalized_updates),
"by_type": _count_requirement_types(normalized_updates),
}
extraction.raw_output = _merge_updates_into_raw_output(
extraction.raw_output,
normalized_updates,
extraction.document_name,
)
def list_srs_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
records: List[Tuple[ToolJob, SRSExtraction]] = (
db.query(ToolJob, SRSExtraction)
.join(SRSExtraction, SRSExtraction.job_id == ToolJob.id)
.filter(ToolJob.user_id == user_id)
.order_by(ToolJob.created_at.desc())
.all()
)
items: List[Dict[str, Any]] = []
for job, extraction in records:
items.append(
{
"jobId": job.id,
"documentName": extraction.document_name,
"generatedAt": extraction.generated_at.isoformat(),
"totalRequirements": extraction.total_requirements,
"status": job.status,
"createdAt": job.created_at.isoformat(),
"updatedAt": job.updated_at.isoformat(),
}
)
return items
def delete_srs_job(db: Session, job: ToolJob) -> None:
db.delete(job)
db.commit()
def build_srs_upload_path(job_id: int) -> Path:
return Path("uploads") / "srs_jobs" / str(job_id)
def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
@@ -185,3 +262,210 @@ def _count_requirement_types(items: List[Dict[str, Any]]) -> Dict[str, int]:
req_type = item.get("requirementType") or "functional"
stats[req_type] = stats.get(req_type, 0) + 1
return stats
def _requirement_model_to_payload(item: SRSRequirement) -> Dict[str, Any]:
return {
"id": item.requirement_uid,
"description": item.description,
"priority": item.priority,
"acceptanceCriteria": item.acceptance_criteria or [],
"sourceField": item.source_field,
"sectionUid": item.section_uid,
"sectionNumber": item.section_number,
"sectionTitle": item.section_title,
"requirementType": item.requirement_type,
"interfaceName": item.interface_name,
"interfaceType": item.interface_type,
"dataSource": item.data_source,
"dataDestination": item.data_destination,
"sortOrder": item.sort_order,
}
def _normalize_update_payload(item: Dict[str, Any], index: int) -> Dict[str, Any]:
requirement_type = _normalize_requirement_type(item.get("requirementType"))
normalized = {
"id": str(item.get("id") or f"REQ-{index + 1:03d}"),
"description": str(item.get("description") or "").strip(),
"priority": item.get("priority") or "",
"acceptanceCriteria": item.get("acceptanceCriteria") or ["待补充验收标准"],
"sourceField": item.get("sourceField") or "文档解析",
"sectionUid": item.get("sectionUid"),
"sectionNumber": item.get("sectionNumber"),
"sectionTitle": item.get("sectionTitle"),
"requirementType": requirement_type,
"interfaceName": item.get("interfaceName") or "",
"interfaceType": item.get("interfaceType") or "",
"dataSource": item.get("dataSource") or "",
"dataDestination": item.get("dataDestination") or "",
"sortOrder": int(item.get("sortOrder") or index),
}
if requirement_type != "interface":
normalized["interfaceName"] = ""
normalized["interfaceType"] = ""
normalized["dataSource"] = ""
normalized["dataDestination"] = ""
return normalized
def _merge_updates_into_raw_output(
raw_output: Dict[str, Any] | None,
updates: List[Dict[str, Any]],
document_name: str,
) -> Dict[str, Any]:
if not isinstance(raw_output, dict) or "需求内容" not in raw_output:
return _build_raw_output_from_flat(updates, document_name)
result = deepcopy(raw_output)
content = result.get("需求内容")
if not isinstance(content, dict):
return _build_raw_output_from_flat(updates, document_name)
updates_by_section = _group_updates_by_section(updates)
_rewrite_content_requirements(content, updates_by_section)
_append_unmatched_sections(content, updates_by_section)
_refresh_metadata(result, updates)
return result
def _group_updates_by_section(updates: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
grouped: Dict[str, List[Dict[str, Any]]] = {}
for item in updates:
key = _section_key(item.get("sectionUid"), item.get("sectionNumber"), item.get("sectionTitle"))
grouped.setdefault(key, []).append(item)
for values in grouped.values():
values.sort(key=lambda value: int(value.get("sortOrder") or 0))
return grouped
def _rewrite_content_requirements(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
for section in content.values():
if not isinstance(section, dict):
continue
section_info = section.get("章节信息") or {}
section_key = _section_key(
section_info.get("章节UID"),
section_info.get("章节编号"),
section_info.get("章节标题"),
)
if section_key in updates_by_section:
section["需求列表"] = [_to_raw_requirement_item(item) for item in updates_by_section.pop(section_key)]
elif "需求列表" in section:
section["需求列表"] = []
children = section.get("子章节")
if isinstance(children, dict):
_rewrite_content_requirements(children, updates_by_section)
def _append_unmatched_sections(content: Dict[str, Any], updates_by_section: Dict[str, List[Dict[str, Any]]]) -> None:
if not updates_by_section:
return
orphan_key = "未归类章节"
orphan_section = content.get(orphan_key)
if not isinstance(orphan_section, dict):
orphan_section = {
"章节信息": {
"章节编号": "",
"章节标题": orphan_key,
"章节级别": 1,
},
"需求列表": [],
}
content[orphan_key] = orphan_section
all_reqs: List[Dict[str, Any]] = orphan_section.get("需求列表") or []
for values in updates_by_section.values():
for item in values:
all_reqs.append(_to_raw_requirement_item(item))
orphan_section["需求列表"] = all_reqs
updates_by_section.clear()
def _refresh_metadata(raw_output: Dict[str, Any], updates: List[Dict[str, Any]]) -> None:
metadata = raw_output.get("文档元数据")
if not isinstance(metadata, dict):
metadata = {}
raw_output["文档元数据"] = metadata
metadata["总需求数"] = len(updates)
metadata["生成时间"] = datetime.now().isoformat()
type_stats: Dict[str, int] = {}
for item in updates:
req_type = item.get("requirementType") or "functional"
cn_type = TYPE_TO_CHINESE.get(req_type, "其他需求")
type_stats[cn_type] = type_stats.get(cn_type, 0) + 1
metadata["需求类型统计"] = type_stats
def _to_raw_requirement_item(item: Dict[str, Any]) -> Dict[str, Any]:
req_type = item.get("requirementType") or "functional"
raw_item = {
"需求类型": TYPE_TO_CHINESE.get(req_type, "其他需求"),
"需求编号": item.get("id") or "",
"需求描述": item.get("description") or "",
"优先级": item.get("priority") or "",
}
if req_type == "interface":
raw_item["接口名称"] = item.get("interfaceName") or ""
raw_item["接口类型"] = item.get("interfaceType") or ""
raw_item["数据来源"] = item.get("dataSource") or ""
raw_item["数据目的地"] = item.get("dataDestination") or ""
return raw_item
def _build_raw_output_from_flat(updates: List[Dict[str, Any]], document_name: str) -> Dict[str, Any]:
grouped = _group_updates_by_section(updates)
content: Dict[str, Any] = {}
for key, values in grouped.items():
number, title = _parse_section_key(key)
section_title = title or "未归类章节"
display_key = f"{number} {section_title}".strip()
content[display_key] = {
"章节信息": {
"章节编号": number,
"章节标题": section_title,
"章节级别": max(len(number.split(".")), 1) if number else 1,
},
"需求列表": [_to_raw_requirement_item(item) for item in values],
}
raw_output = {
"文档元数据": {
"标题": document_name,
"生成时间": datetime.now().isoformat(),
"总需求数": len(updates),
"需求类型统计": {},
},
"需求内容": content,
}
_refresh_metadata(raw_output, updates)
return raw_output
def _section_key(section_uid: Any, section_number: Any, section_title: Any) -> str:
uid = str(section_uid or "").strip()
if uid:
return f"uid::{uid}"
number = str(section_number or "").strip()
title = str(section_title or "").strip()
return f"number::{number}::title::{title}"
def _parse_section_key(value: str) -> Tuple[str, str]:
if value.startswith("uid::"):
return "", "未归类章节"
number = ""
title = ""
parts = value.split("::")
if len(parts) >= 4:
number = parts[1]
title = parts[3]
return number, title

View File

@@ -0,0 +1,236 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Any, Dict, List
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
items.extend(current)
return items
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
if not documents:
continue
store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb.id}",
embedding_function=embeddings,
)
kb_vector_stores.append({"kb_id": kb.id, "store": store})
return kb_vector_stores
def _resolve_knowledge_context(
db: Session,
*,
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
) -> str:
if knowledge_base_id is None:
return ""
try:
knowledge_bases = (
db.query(KnowledgeBase)
.filter(
KnowledgeBase.id == knowledge_base_id,
KnowledgeBase.user_id == user_id,
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
if not kb_vector_stores:
return ""
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
)
rows = asyncio.run(
retriever.retrieve(
query=requirement_text,
kb_vector_stores=kb_vector_stores,
fetch_k_per_kb=16,
top_k=8,
)
)
if rows:
return format_retrieval_context(rows)
except Exception:
return ""
return ""
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
test_items = [
{
"id": item.get("id"),
"content": item.get("content"),
}
for item in _flatten_record(pipeline_result.get("test_items", {}))
]
test_cases = [
{
"id": item.get("id"),
"itemId": item.get("item_id"),
"testContent": item.get("test_content"),
"operationSteps": item.get("operation_steps", []),
"expectedResultPlaceholder": item.get("expected_result_placeholder"),
}
for item in _flatten_record(pipeline_result.get("test_cases", {}))
]
expected_results = [
{
"id": item.get("id"),
"caseId": item.get("case_id"),
"result": item.get("result"),
}
for item in _flatten_record(pipeline_result.get("expected_results", {}))
]
return {
**req,
"测试项": test_items,
"测试用例": test_cases,
"预期结果": expected_results,
}
def _mark_job_failed(job_id: int, error_message: str) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
job.status = "failed"
job.completed_at = datetime.utcnow()
job.error_message = error_message[:2000]
db.commit()
finally:
db.close()
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
db = SessionLocal()
try:
job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
if not job:
return
requirements = payload.get("requirements") or []
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
source_job_id = payload.get("source_job_id")
knowledge_base_id = payload.get("knowledge_base_id")
job.status = "processing"
job.started_at = datetime.utcnow()
job.error_message = None
job.output_summary = {
"source_document_name": source_document_name,
"current_step": 0,
"total_steps": len(requirements),
}
db.commit()
generated_requirements: List[Dict[str, Any]] = []
for index, req in enumerate(requirements):
req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
job.output_summary = {
"source_document_name": source_document_name,
"current_step": index + 1,
"total_steps": len(requirements),
"current_requirement_id": req_id,
}
db.commit()
description = str(req.get("description") or "").strip()
if not description:
generated_requirements.append(
{
**req,
"测试项": [],
"测试用例": [],
"预期结果": [],
}
)
continue
knowledge_context = _resolve_knowledge_context(
db,
user_id=job.user_id,
requirement_text=description,
knowledge_base_id=knowledge_base_id,
)
pipeline_result = run_testing_pipeline(
user_requirement_text=description,
requirement_type_input=req.get("requirementType"),
debug=False,
knowledge_context=knowledge_context,
use_model_generation=True,
max_items_per_group=12,
cases_per_item=2,
max_focus_points=6,
max_llm_calls=10,
)
generated_requirements.append(_build_generated_requirement(req, pipeline_result))
generated_at = datetime.utcnow()
generated_file = {
"sourceDocument": source_document_name,
"sourceJobId": source_job_id,
"generatedAt": generated_at.isoformat(),
"totalRequirements": len(generated_requirements),
"knowledgeBaseId": knowledge_base_id,
"requirements": generated_requirements,
}
generation = TestingGeneration(
job_id=job.id,
source_job_id=source_job_id,
source_document_name=source_document_name,
generated_at=generated_at,
total_requirements=len(generated_requirements),
knowledge_base_id=knowledge_base_id,
generated_file=generated_file,
)
db.add(generation)
job.status = "completed"
job.completed_at = datetime.utcnow()
job.output_summary = {
"source_document_name": source_document_name,
"current_step": len(generated_requirements),
"total_steps": len(generated_requirements),
"total_requirements": len(generated_requirements),
"knowledge_base_id": knowledge_base_id,
}
db.commit()
except Exception as exc:
db.rollback()
_mark_job_failed(job_id, str(exc))
finally:
db.close()

View File

@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Any, Dict, List, Tuple
from sqlalchemy.orm import Session
from app.models.tooling import TestingGeneration, ToolJob
from app.schemas.tooling import TestingGenerationSaveRequest
TESTING_TOOL_NAME = "testing.case_generator"
def _resolve_total_requirements(generated_file: Dict[str, Any]) -> int:
requirements = generated_file.get("requirements")
if isinstance(requirements, list):
return len(requirements)
total = generated_file.get("totalRequirements")
if isinstance(total, int) and total >= 0:
return total
return 0
def build_testing_generation_response(job: ToolJob, generation: TestingGeneration) -> Dict[str, Any]:
return {
"jobId": job.id,
"sourceJobId": generation.source_job_id,
"sourceDocumentName": generation.source_document_name,
"generatedAt": generation.generated_at.isoformat(),
"totalRequirements": generation.total_requirements,
"knowledgeBaseId": generation.knowledge_base_id,
"generatedFile": generation.generated_file or {},
}
def create_testing_generation(
db: Session,
user_id: int,
payload: TestingGenerationSaveRequest,
) -> Dict[str, Any]:
now = datetime.utcnow()
total_requirements = _resolve_total_requirements(payload.generated_file)
job = ToolJob(
user_id=user_id,
tool_name=TESTING_TOOL_NAME,
status="completed",
input_file_name=payload.source_document_name,
input_file_path="",
started_at=now,
completed_at=now,
output_summary={
"source_document_name": payload.source_document_name,
"total_requirements": total_requirements,
"knowledge_base_id": payload.knowledge_base_id,
},
)
db.add(job)
db.flush()
generation = TestingGeneration(
job_id=job.id,
source_job_id=payload.source_job_id,
source_document_name=payload.source_document_name,
generated_at=now,
total_requirements=total_requirements,
knowledge_base_id=payload.knowledge_base_id,
generated_file=payload.generated_file,
)
db.add(generation)
db.commit()
db.refresh(job)
db.refresh(generation)
return build_testing_generation_response(job, generation)
def list_testing_history(db: Session, user_id: int) -> List[Dict[str, Any]]:
rows: List[Tuple[ToolJob, TestingGeneration]] = (
db.query(ToolJob, TestingGeneration)
.join(TestingGeneration, TestingGeneration.job_id == ToolJob.id)
.filter(
ToolJob.user_id == user_id,
ToolJob.tool_name == TESTING_TOOL_NAME,
)
.order_by(ToolJob.created_at.desc())
.all()
)
items: List[Dict[str, Any]] = []
for job, generation in rows:
items.append(
{
"jobId": job.id,
"sourceJobId": generation.source_job_id,
"sourceDocumentName": generation.source_document_name,
"generatedAt": generation.generated_at.isoformat(),
"totalRequirements": generation.total_requirements,
"knowledgeBaseId": generation.knowledge_base_id,
"status": job.status,
"createdAt": job.created_at.isoformat(),
"updatedAt": job.updated_at.isoformat(),
}
)
return items
def delete_testing_generation(db: Session, job: ToolJob) -> None:
db.delete(job)
db.commit()

View File

@@ -4,7 +4,6 @@ from time import perf_counter
from typing import Any, Dict, List, Optional
from uuid import uuid4
from app.services.llm.llm_factory import LLMFactory
from app.services.testing_pipeline.tools import build_default_tool_chain
@@ -42,6 +41,8 @@ def run_testing_pipeline(
llm_model = None
if use_model_generation:
try:
from app.services.llm.llm_factory import LLMFactory
llm_model = LLMFactory.create(streaming=False)
except Exception:
llm_model = None