237 lines
7.6 KiB
Python
237 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import settings
|
|
from app.db.session import SessionLocal
|
|
from app.models.knowledge import Document, KnowledgeBase
|
|
from app.models.tooling import TestingGeneration, ToolJob
|
|
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
|
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
|
from app.services.testing_pipeline import run_testing_pipeline
|
|
from app.services.vector_store import VectorStoreFactory
|
|
|
|
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
|
items: List[Dict[str, Any]] = []
|
|
for current in value.values():
|
|
items.extend(current)
|
|
return items
|
|
|
|
|
|
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
|
embeddings = EmbeddingsFactory.create()
|
|
kb_vector_stores: List[Dict[str, Any]] = []
|
|
|
|
for kb in knowledge_bases:
|
|
documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
|
|
if not documents:
|
|
continue
|
|
|
|
store = VectorStoreFactory.create(
|
|
store_type=settings.VECTOR_STORE_TYPE,
|
|
collection_name=f"kb_{kb.id}",
|
|
embedding_function=embeddings,
|
|
)
|
|
kb_vector_stores.append({"kb_id": kb.id, "store": store})
|
|
|
|
return kb_vector_stores
|
|
|
|
|
|
def _resolve_knowledge_context(
|
|
db: Session,
|
|
*,
|
|
user_id: int,
|
|
requirement_text: str,
|
|
knowledge_base_id: int | None,
|
|
) -> str:
|
|
if knowledge_base_id is None:
|
|
return ""
|
|
|
|
try:
|
|
knowledge_bases = (
|
|
db.query(KnowledgeBase)
|
|
.filter(
|
|
KnowledgeBase.id == knowledge_base_id,
|
|
KnowledgeBase.user_id == user_id,
|
|
)
|
|
.all()
|
|
)
|
|
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
|
|
if not kb_vector_stores:
|
|
return ""
|
|
|
|
retriever = MultiKBRetriever(
|
|
reranker_weight=settings.RERANKER_WEIGHT,
|
|
)
|
|
rows = asyncio.run(
|
|
retriever.retrieve(
|
|
query=requirement_text,
|
|
kb_vector_stores=kb_vector_stores,
|
|
fetch_k_per_kb=16,
|
|
top_k=8,
|
|
)
|
|
)
|
|
if rows:
|
|
return format_retrieval_context(rows)
|
|
except Exception:
|
|
return ""
|
|
|
|
return ""
|
|
|
|
|
|
def _build_generated_requirement(req: Dict[str, Any], pipeline_result: Dict[str, Any]) -> Dict[str, Any]:
    """Combine a requirement with the pipeline output generated for it.

    The ``test_items``, ``test_cases`` and ``expected_results`` groups of
    ``pipeline_result`` are flattened and each entry is projected onto the
    output field names (snake_case pipeline keys become camelCase).

    Args:
        req: The original requirement; all of its keys are carried over.
        pipeline_result: Pipeline output holding the three grouped records.

    Returns:
        A new dict with the keys of ``req`` plus "测试项", "测试用例" and
        "预期结果", which override same-named keys of ``req``.
    """
    item_rows = []
    for entry in _flatten_record(pipeline_result.get("test_items", {})):
        item_rows.append({"id": entry.get("id"), "content": entry.get("content")})

    case_rows = []
    for entry in _flatten_record(pipeline_result.get("test_cases", {})):
        case_row = {
            "id": entry.get("id"),
            "itemId": entry.get("item_id"),
            "testContent": entry.get("test_content"),
            "operationSteps": entry.get("operation_steps", []),
            "expectedResultPlaceholder": entry.get("expected_result_placeholder"),
        }
        case_rows.append(case_row)

    result_rows = []
    for entry in _flatten_record(pipeline_result.get("expected_results", {})):
        result_row = {
            "id": entry.get("id"),
            "caseId": entry.get("case_id"),
            "result": entry.get("result"),
        }
        result_rows.append(result_row)

    # Copy first so the caller's requirement dict is never mutated.
    merged = {**req}
    merged.update(
        {
            "测试项": item_rows,
            "测试用例": case_rows,
            "预期结果": result_rows,
        }
    )
    return merged
|
|
|
|
|
|
def _mark_job_failed(job_id: int, error_message: str) -> None:
    """Persist a failed state on the ``ToolJob`` with the given id.

    Uses a session of its own, created from ``SessionLocal`` and always
    closed before returning. Does nothing when no such job exists.

    Args:
        job_id: Primary key of the job to update.
        error_message: Failure description; only the first 2000 characters
            are stored.
    """
    session = SessionLocal()
    try:
        failed_job = session.query(ToolJob).filter(ToolJob.id == job_id).first()
        if failed_job:
            failed_job.status = "failed"
            failed_job.completed_at = datetime.utcnow()
            failed_job.error_message = error_message[:2000]
            session.commit()
    finally:
        session.close()
|
|
|
|
|
|
def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
    """Run the testing pipeline for every requirement of a job and store the result.

    The job row is moved to ``"processing"``, its ``output_summary`` is
    committed before each requirement so progress is visible while the job
    runs, and on success a ``TestingGeneration`` row is added and the job is
    marked ``"completed"``. Any exception rolls back the session, is logged,
    and marks the job ``"failed"`` via ``_mark_job_failed``; nothing is raised
    to the caller. Returns silently when the job does not exist.

    Args:
        job_id: Primary key of the ``ToolJob`` to process.
        payload: Job input. Keys read here: ``requirements`` (list of
            requirement dicts), ``source_document_name``, ``source_job_id``
            and ``knowledge_base_id``.
    """
    db = SessionLocal()
    try:
        job = db.query(ToolJob).filter(ToolJob.id == job_id).first()
        if not job:
            return

        requirements = payload.get("requirements") or []
        total_steps = len(requirements)
        source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
        source_job_id = payload.get("source_job_id")
        knowledge_base_id = payload.get("knowledge_base_id")

        job.status = "processing"
        job.started_at = datetime.utcnow()
        job.error_message = None
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": 0,
            "total_steps": total_steps,
        }
        db.commit()

        generated_requirements: List[Dict[str, Any]] = []

        for index, req in enumerate(requirements):
            req_id = str(req.get("id") or f"REQ-{index + 1:03d}")
            # Commit progress before the (potentially slow) pipeline call so
            # readers of the job row see which requirement is in flight.
            job.output_summary = {
                "source_document_name": source_document_name,
                "current_step": index + 1,
                "total_steps": total_steps,
                "current_requirement_id": req_id,
            }
            db.commit()

            description = str(req.get("description") or "").strip()
            if not description:
                # Nothing to generate from; keep the requirement with empty outputs.
                generated_requirements.append(
                    {
                        **req,
                        "测试项": [],
                        "测试用例": [],
                        "预期结果": [],
                    }
                )
                continue

            knowledge_context = _resolve_knowledge_context(
                db,
                user_id=job.user_id,
                requirement_text=description,
                knowledge_base_id=knowledge_base_id,
            )

            pipeline_result = run_testing_pipeline(
                user_requirement_text=description,
                requirement_type_input=req.get("requirementType"),
                debug=False,
                knowledge_context=knowledge_context,
                use_model_generation=True,
                max_items_per_group=12,
                cases_per_item=2,
                max_focus_points=6,
                max_llm_calls=10,
            )
            generated_requirements.append(_build_generated_requirement(req, pipeline_result))

        generated_at = datetime.utcnow()
        generated_count = len(generated_requirements)
        generated_file = {
            "sourceDocument": source_document_name,
            "sourceJobId": source_job_id,
            "generatedAt": generated_at.isoformat(),
            "totalRequirements": generated_count,
            "knowledgeBaseId": knowledge_base_id,
            "requirements": generated_requirements,
        }

        generation = TestingGeneration(
            job_id=job.id,
            source_job_id=source_job_id,
            source_document_name=source_document_name,
            generated_at=generated_at,
            total_requirements=generated_count,
            knowledge_base_id=knowledge_base_id,
            generated_file=generated_file,
        )
        db.add(generation)

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {
            "source_document_name": source_document_name,
            "current_step": generated_count,
            "total_steps": generated_count,
            "total_requirements": generated_count,
            "knowledge_base_id": knowledge_base_id,
        }
        db.commit()
    except Exception as exc:
        db.rollback()
        # The stored message is only str(exc); log the traceback so the
        # failure can actually be diagnosed.
        import logging

        logging.getLogger(__name__).exception("Testing generation job %s failed", job_id)
        # Some exceptions stringify to "" (e.g. raised without arguments);
        # fall back to the class name so the job never fails without a cause.
        _mark_job_failed(job_id, str(exc) or exc.__class__.__name__)
    finally:
        db.close()
|