增加代码知识库;修复文档处理内容;增加API设置

This commit is contained in:
2026-05-16 20:20:10 +08:00
parent 69b49d28b2
commit 7aa3ce3294
119 changed files with 182273 additions and 793 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.api_v1 import api_keys, auth, chat, knowledge_base, testing, tools
from app.api.api_v1 import api_keys, auth, chat, consistency, knowledge_base, model_configs, testing, tools
api_router = APIRouter()
@@ -7,5 +7,7 @@ api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(knowledge_base.router, prefix="/knowledge-base", tags=["knowledge-base"])
api_router.include_router(chat.router, prefix="/chat", tags=["chat"])
api_router.include_router(api_keys.router, prefix="/api-keys", tags=["api-keys"])
api_router.include_router(model_configs.router, prefix="/model-configs", tags=["model-configs"])
api_router.include_router(testing.router, prefix="/testing", tags=["testing"])
api_router.include_router(tools.router, prefix="/tools", tags=["tools"])
api_router.include_router(tools.router, prefix="/tools", tags=["tools"])
api_router.include_router(consistency.router, prefix="/consistency", tags=["consistency"])

View File

@@ -39,7 +39,7 @@ def create_api_key(
api_key = APIKeyService.create_api_key(
db=db, user_id=current_user.id, name=api_key_in.name
)
logger.info(f"API key created: {api_key.key} for user {current_user.id}")
logger.info("API key created: id=%s for user %s", api_key.id, current_user.id)
return api_key
@router.put("/{id}", response_model=schemas.APIKey)
@@ -60,7 +60,7 @@ def update_api_key(
raise HTTPException(status_code=403, detail="Not enough permissions")
api_key = APIKeyService.update_api_key(db=db, api_key=api_key, update_data=api_key_in)
logger.info(f"API key updated: {api_key.key} for user {current_user.id}")
logger.info("API key updated: id=%s for user %s", api_key.id, current_user.id)
return api_key
@router.delete("/{id}", response_model=schemas.APIKey)
@@ -80,5 +80,5 @@ def delete_api_key(
raise HTTPException(status_code=403, detail="Not enough permissions")
APIKeyService.delete_api_key(db=db, api_key=api_key)
logger.info(f"API key deleted: {api_key.key} for user {current_user.id}")
logger.info("API key deleted: id=%s for user %s", api_key.id, current_user.id)
return api_key

View File

@@ -120,7 +120,8 @@ async def create_message(
messages=messages,
knowledge_base_ids=knowledge_base_ids,
chat_id=chat_id,
db=db
db=db,
user_id=current_user.id,
):
yield chunk
@@ -152,4 +153,4 @@ def delete_chat(
db.delete(chat)
db.commit()
return {"status": "success"}
return {"status": "success"}

View File

@@ -0,0 +1,338 @@
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import Response
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.tooling import ConsistencyResult
from app.models.user import User
from app.schemas.consistency import (
AutoConsistencyJobCreateResponse,
AutoConsistencyJobStatusResponse,
CodeKnowledgeBaseCreate,
CodeKnowledgeBaseResponse,
CodeKnowledgeBaseUploadResponse,
CodeQuestionRequest,
CodeQuestionResponse,
ConsistencyJobCreate,
ConsistencyJobCreateResponse,
ConsistencyJobResponse,
ConsistencyResultResponse,
)
from app.services.consistency.exporter import export_excel, export_json, export_markdown
from app.services.consistency_job_service import (
AUTO_UPLOAD_ROOT,
CODE_UPLOAD_ROOT,
ask_code_kb,
create_code_kb,
create_auto_consistency_tool_job,
create_consistency_job,
create_uploaded_code_kb,
get_owned_auto_job,
get_owned_code_kb,
get_owned_consistency_job,
list_code_kbs,
list_consistency_jobs,
result_model_to_export_dict,
run_auto_consistency_job,
run_code_kb_build,
run_consistency_job,
safe_upload_name,
save_uploaded_bytes,
)
from app.services.model_config import ModelConfigService
router = APIRouter()
def datetime_path() -> str:
    """Return the current UTC time as a 20-digit path component.

    The value is formatted as ``YYYYMMDDHHMMSSffffff`` (microsecond
    resolution). The upload endpoints in this module use it as a
    per-request sub-directory name.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    # ``datetime.utcnow()`` is deprecated since Python 3.12. An aware UTC
    # "now" formats to exactly the same digits with this format string.
    return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")
async def _save_code_uploads(files: List[UploadFile], target_dir: Path) -> str:
    """Persist uploaded code files under ``target_dir`` and return the source directory.

    Every non-empty upload is written through ``save_uploaded_bytes`` using a
    name produced by ``safe_upload_name``. When exactly one file was submitted
    and saving it yielded a directory (presumably an extracted archive --
    confirm against ``save_uploaded_bytes``), that directory is returned
    instead of ``target_dir``.

    Args:
        files: Uploads received from the multipart request.
        target_dir: Directory the uploads are stored under.

    Returns:
        The resolved absolute path of the directory holding the code.

    Raises:
        HTTPException: 400 when no files were submitted, or when every
            submitted file was empty so nothing was written.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No code files uploaded.")

    saved_count = 0
    extracted_dirs: List[Path] = []
    for upload in files:
        content = await upload.read()
        if not content:
            # Zero-byte uploads carry no code; skip them.
            continue
        saved = save_uploaded_bytes(target_dir, safe_upload_name(upload.filename), content)
        saved_count += 1
        if saved.is_dir():
            extracted_dirs.append(saved)

    if saved_count == 0:
        # Without this guard the function returned ``target_dir`` even though
        # nothing had been written into it, and the caller went on to register
        # a knowledge base for a directory with no code.
        raise HTTPException(status_code=400, detail="Uploaded code files are empty.")

    source_dir = extracted_dirs[0] if len(files) == 1 and extracted_dirs else target_dir
    return str(source_dir.resolve())
async def _save_requirement_upload(file: UploadFile, target_dir: Path) -> str:
    """Validate and store the uploaded requirement document.

    Returns:
        The resolved absolute path of the stored file.

    Raises:
        HTTPException: 400 when the sanitized file name does not end in
            ``.pdf``/``.docx`` (case-insensitive) or the upload has no bytes.
    """
    allowed_suffixes = {".pdf", ".docx"}
    safe_name = safe_upload_name(file.filename)
    suffix = Path(safe_name).suffix.lower()
    if suffix not in allowed_suffixes:
        raise HTTPException(status_code=400, detail="Requirement file must be .pdf or .docx.")

    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Requirement file is empty.")

    stored_path = save_uploaded_bytes(target_dir, safe_name, data)
    return str(stored_path.resolve())
@router.post("/code-kbs", response_model=CodeKnowledgeBaseResponse)
async def register_code_kb(
    payload: CodeKnowledgeBaseCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Register a code knowledge base for the authenticated user from the
    # paths supplied in `payload`.
    # FileNotFoundError / RuntimeError / ValueError raised by create_code_kb
    # are reported to the client as HTTP 400 with the exception text.
    try:
        code_kb = create_code_kb(db, current_user.id, payload)
    except (FileNotFoundError, RuntimeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return code_kb
@router.post("/code-kbs/upload", response_model=CodeKnowledgeBaseUploadResponse)
async def upload_and_build_code_kb(
    background_tasks: BackgroundTasks,
    name: str = Form(...),
    use_semantic: bool = Form(True),
    files: List[UploadFile] = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Store the uploaded code under a per-user, per-request directory, create
    # the knowledge base record, and schedule the build as a background task.
    # The response only carries the new record's id and status.

    # An active model config is only required when a semantic build is requested.
    if use_semantic:
        try:
            ModelConfigService.require_active_config(db, current_user.id)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    upload_root = CODE_UPLOAD_ROOT / str(current_user.id) / datetime_path()
    project_path = await _save_code_uploads(files, upload_root)
    artifacts_dir = (upload_root / "artifacts").resolve()

    code_kb = create_uploaded_code_kb(
        db=db,
        user_id=current_user.id,
        name=name,
        project_path=project_path,
        output_dir=str(artifacts_dir),
    )
    background_tasks.add_task(run_code_kb_build, code_kb.id, use_semantic)
    return {"id": code_kb.id, "status": code_kb.status}
@router.get("/code-kbs", response_model=List[CodeKnowledgeBaseResponse])
async def get_code_kbs(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # List the code knowledge bases that list_code_kbs returns for the caller.
    owned_code_kbs = list_code_kbs(db, current_user.id)
    return owned_code_kbs
@router.get("/code-kbs/{code_kb_id}", response_model=CodeKnowledgeBaseResponse)
async def get_code_kb(
    code_kb_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Fetch one code knowledge base looked up by (user id, kb id).
    # A falsy lookup result is answered with 404.
    code_kb = get_owned_code_kb(db, current_user.id, code_kb_id)
    if code_kb:
        return code_kb
    raise HTTPException(status_code=404, detail="Code knowledge base not found.")
@router.delete("/code-kbs/{code_kb_id}")
async def delete_code_kb(
    code_kb_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Delete the database record of a code knowledge base looked up by
    # (user id, kb id). Only the database row is removed here; nothing on
    # disk is touched by this handler.
    target = get_owned_code_kb(db, current_user.id, code_kb_id)
    if not target:
        raise HTTPException(status_code=404, detail="Code knowledge base not found.")

    db.delete(target)
    db.commit()
    return {"message": "deleted"}
@router.post("/code-kbs/{code_kb_id}/ask", response_model=CodeQuestionResponse)
async def ask_code_kb_api(
    code_kb_id: int,
    payload: CodeQuestionRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Answer a question against one of the caller's code knowledge bases.
    # 404 when the knowledge base lookup is falsy; RuntimeError / ValueError
    # from the model-config lookup or from ask_code_kb become HTTP 400.
    code_kb = get_owned_code_kb(db, current_user.id, code_kb_id)
    if not code_kb:
        raise HTTPException(status_code=404, detail="Code knowledge base not found.")

    try:
        model_profile = ModelConfigService.require_active_config(db, current_user.id)
        # The profile is marked as used before the question is answered.
        ModelConfigService.touch_last_used(db, model_profile)
        answer = ask_code_kb(
            code_kb=code_kb,
            question=payload.question,
            top_k=payload.top_k,
            min_similarity=payload.min_similarity,
            use_llm=payload.use_llm,
            model_profile=model_profile,
        )
    except (RuntimeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return answer
@router.post("/jobs", response_model=ConsistencyJobCreateResponse)
async def create_job(
    background_tasks: BackgroundTasks,
    payload: ConsistencyJobCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Create a consistency job for the caller and schedule it to run in the
    # background. A ValueError from the model-config check or from job
    # creation is returned as HTTP 400.
    user_id = current_user.id
    try:
        ModelConfigService.require_active_config(db, user_id)
        job = create_consistency_job(db, user_id, payload)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    background_tasks.add_task(run_consistency_job, job.id)
    return {"job_id": job.id, "status": job.status}
@router.post("/auto-jobs", response_model=AutoConsistencyJobCreateResponse)
async def create_auto_job(
    background_tasks: BackgroundTasks,
    requirement_file: UploadFile = File(...),
    code_files: List[UploadFile] = File(...),
    code_kb_name: str = Form("uploaded-code-kb"),
    # Bounds mirror the limits declared on ConsistencyJobCreate so that both
    # job-creation entry points accept the same parameter ranges.
    top_k: int = Form(8, ge=1, le=30),
    max_call_hops: int = Form(2, ge=0, le=4),
    min_similarity: float = Form(0.55, ge=0.0, le=1.0),
    use_llm: bool = Form(True),
    use_semantic: bool = Form(True),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Accept a requirement document plus code files in one multipart request,
    # store both under a per-user, per-request directory, create the tool job
    # record, and schedule run_auto_consistency_job in the background.
    #
    # Errors:
    #   400 - no active model config (ValueError from require_active_config),
    #         or an upload rejected by the _save_* helpers.
    #   422 - form values outside the declared ge/le bounds (FastAPI validation).
    try:
        ModelConfigService.require_active_config(db, current_user.id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    timestamp = datetime_path()
    target_dir = AUTO_UPLOAD_ROOT / str(current_user.id) / timestamp
    # The requirement document and the code live in sibling sub-directories.
    requirement_path = await _save_requirement_upload(requirement_file, target_dir / "requirement")
    code_source_dir = await _save_code_uploads(code_files, target_dir / "code")

    tool_job = create_auto_consistency_tool_job(
        db=db,
        user_id=current_user.id,
        requirement_file_path=requirement_path,
        requirement_file_name=safe_upload_name(requirement_file.filename),
        code_source_dir=code_source_dir,
        code_kb_name=code_kb_name,
        top_k=top_k,
        max_call_hops=max_call_hops,
        min_similarity=min_similarity,
        use_llm=use_llm,
        use_semantic=use_semantic,
    )
    background_tasks.add_task(run_auto_consistency_job, tool_job.id)
    return {"job_id": tool_job.id, "status": tool_job.status}
@router.get("/auto-jobs/{job_id}", response_model=AutoConsistencyJobStatusResponse)
async def get_auto_job(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Report the status of an auto consistency job looked up by (user id, job id).
    # Progress fields are read from the job's output_summary mapping and are
    # None when the summary is missing or lacks the key.
    job = get_owned_auto_job(db, current_user.id, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Auto consistency job not found.")

    summary = job.output_summary or {}
    progress_keys = ("current_step", "srs_extraction_id", "code_kb_id", "consistency_job_id")
    progress = {key: summary.get(key) for key in progress_keys}
    return {
        "job_id": job.id,
        "status": job.status,
        "error_message": job.error_message,
        **progress,
        "created_at": job.created_at,
        "updated_at": job.updated_at,
    }
@router.get("/jobs", response_model=List[ConsistencyJobResponse])
async def get_jobs(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # List the consistency jobs that list_consistency_jobs returns for the caller.
    jobs = list_consistency_jobs(db, current_user.id)
    return jobs
@router.get("/jobs/{job_id}", response_model=ConsistencyJobResponse)
async def get_job(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Fetch a single consistency job looked up by (user id, job id); 404 otherwise.
    job = get_owned_consistency_job(db, current_user.id, job_id)
    if job:
        return job
    raise HTTPException(status_code=404, detail="Consistency job not found.")
@router.get("/jobs/{job_id}/results", response_model=List[ConsistencyResultResponse])
async def get_job_results(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # Return every result row of one of the caller's consistency jobs,
    # ordered by result id. 404 when the job lookup is falsy.
    job = get_owned_consistency_job(db, current_user.id, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Consistency job not found.")

    results_query = db.query(ConsistencyResult).filter(ConsistencyResult.job_id == job.id)
    ordered_results = results_query.order_by(ConsistencyResult.id).all()
    return ordered_results
@router.get("/jobs/{job_id}/export")
async def export_job_results(
    job_id: int,
    format: str = Query(default="json", pattern="^(json|markdown|md|excel|xlsx)$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Response:
    # Download the results of one of the caller's consistency jobs as an
    # attachment. `format` selects Markdown ("markdown"/"md"), Excel
    # ("excel"/"xlsx") or JSON (anything else the pattern allows).
    # 404 when the job lookup is falsy; a RuntimeError from the Excel
    # exporter is reported as HTTP 500.
    job = get_owned_consistency_job(db, current_user.id, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Consistency job not found.")

    results_query = db.query(ConsistencyResult).filter(ConsistencyResult.job_id == job.id)
    payload = [
        result_model_to_export_dict(row)
        for row in results_query.order_by(ConsistencyResult.id).all()
    ]

    if format in {"markdown", "md"}:
        body = export_markdown(payload).encode("utf-8")
        media_type = "text/markdown; charset=utf-8"
        disposition = f'attachment; filename="consistency-job-{job.id}.md"'
    elif format in {"excel", "xlsx"}:
        try:
            body = export_excel(payload)
        except RuntimeError as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        media_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        disposition = f'attachment; filename="consistency-job-{job.id}.xlsx"'
    else:
        body = export_json(payload)
        media_type = "application/json; charset=utf-8"
        disposition = f'attachment; filename="consistency-job-{job.id}.json"'

    return Response(
        body,
        media_type=media_type,
        headers={"Content-Disposition": disposition},
    )

View File

@@ -28,6 +28,7 @@ from app.core.minio import get_minio_client
from minio.error import MinioException
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.model_config import ModelConfigService
router = APIRouter()
@@ -163,19 +164,23 @@ async def delete_knowledge_base(
# Get all document file paths before deletion
document_paths = [doc.file_path for doc in kb.documents]
cleanup_errors = []
# Initialize services
minio_client = get_minio_client()
embeddings = EmbeddingsFactory.create()
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
vector_store = None
try:
model_profile = ModelConfigService.get_active_config(db, current_user.id)
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
except Exception as e:
cleanup_errors.append(f"Failed to initialize vector store cleanup: {str(e)}")
# Clean up external resources first
cleanup_errors = []
# 1. Clean up MinIO files
try:
# Delete all objects with prefix kb_{kb_id}/
@@ -188,12 +193,13 @@ async def delete_knowledge_base(
logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")
# 2. Clean up vector store
try:
vector_store._store.delete_collection(f"kb_{kb_id}")
logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
except Exception as e:
cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")
if vector_store is not None:
try:
vector_store._store.delete_collection(f"kb_{kb_id}")
logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
except Exception as e:
cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")
# Finally, delete database records in a single transaction
db.delete(kb)
@@ -366,6 +372,11 @@ async def process_kb_documents(
if not upload_ids:
return {"tasks": []}
try:
ModelConfigService.require_active_config(db, current_user.id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
uploads = db.query(DocumentUpload).filter(DocumentUpload.id.in_(upload_ids)).all()
uploads_dict = {upload.id: upload for upload in uploads}
@@ -411,12 +422,13 @@ async def process_kb_documents(
background_tasks.add_task(
add_processing_tasks_to_queue,
task_data,
kb_id
kb_id,
current_user.id,
)
return {"tasks": task_info}
async def add_processing_tasks_to_queue(task_data, kb_id):
async def add_processing_tasks_to_queue(task_data, kb_id, user_id):
"""Helper function to add document processing tasks to the queue without blocking the main response."""
for data in task_data:
asyncio.create_task(
@@ -425,7 +437,8 @@ async def add_processing_tasks_to_queue(task_data, kb_id):
data["file_name"],
kb_id,
data["task_id"],
None
None,
user_id=user_id,
)
)
logger.info(f"Added {len(task_data)} document processing tasks to queue")
@@ -551,7 +564,11 @@ async def test_retrieval(
detail=f"Knowledge base {request.kb_id} not found",
)
embeddings = EmbeddingsFactory.create()
try:
model_profile = ModelConfigService.require_active_config(db, current_user.id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
@@ -571,5 +588,7 @@ async def test_retrieval(
return {"results": response}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,78 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.user import User
from app.schemas.model_config import (
ModelConfigCreate,
ModelConfigResponse,
ModelConfigUpdate,
ModelProviderOptionsResponse,
)
from app.services.model_config import ModelConfigService, provider_options_response
router = APIRouter()
@router.get("/providers", response_model=ModelProviderOptionsResponse)
def read_model_provider_options(
    current_user: User = Depends(get_current_user),
) -> Any:
    # Return the provider options payload. `current_user` is not read here;
    # the dependency is declared so the route still runs get_current_user.
    options = provider_options_response()
    return options
@router.get("", response_model=List[ModelConfigResponse])
def read_model_configs(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    # List the model configs that ModelConfigService returns for the caller.
    configs = ModelConfigService.list_configs(db, current_user.id)
    return configs
@router.post("", response_model=ModelConfigResponse)
def create_model_config(
    *,
    db: Session = Depends(get_db),
    payload: ModelConfigCreate,
    current_user: User = Depends(get_current_user),
) -> Any:
    # Create a model config for the caller; a ValueError from the service is
    # returned as HTTP 400 with the exception text.
    try:
        created = ModelConfigService.create_config(db, current_user.id, payload)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return created
@router.put("/{config_id}", response_model=ModelConfigResponse)
def update_model_config(
    *,
    db: Session = Depends(get_db),
    config_id: int,
    payload: ModelConfigUpdate,
    current_user: User = Depends(get_current_user),
) -> Any:
    # Update one of the caller's model configs.
    # 404 when the (user id, config id) lookup is falsy; a ValueError from
    # the service is returned as HTTP 400.
    existing = ModelConfigService.get_config(db, current_user.id, config_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Model config not found.")

    try:
        updated = ModelConfigService.update_config(db, existing, payload)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return updated
@router.delete("/{config_id}", response_model=ModelConfigResponse)
def delete_model_config(
    *,
    db: Session = Depends(get_db),
    config_id: int,
    current_user: User = Depends(get_current_user),
) -> Any:
    # Delete one of the caller's model configs and return what it looked like.
    # 404 when the (user id, config id) lookup is falsy.
    doomed = ModelConfigService.get_config(db, current_user.id, config_id)
    if not doomed:
        raise HTTPException(status_code=404, detail="Model config not found.")

    # Build the response model before deleting, while the item is still loaded.
    snapshot = ModelConfigResponse.model_validate(doomed)
    ModelConfigService.delete_config(db, doomed)
    return snapshot

View File

@@ -12,6 +12,8 @@ from app.models.knowledge import Document, KnowledgeBase
from app.models.user import User
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
@@ -21,8 +23,12 @@ logger = logging.getLogger(__name__)
MODEL_PIPELINE_TIMEOUT_SECONDS = 300
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
async def _build_kb_vector_stores(
db: Session,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -47,10 +53,9 @@ async def generate_testing_content(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Any:
_ = current_user
model_profile = ModelConfigService.get_active_config(db, current_user.id)
knowledge_context = (payload.knowledge_context or "").strip()
if payload.knowledge_base_ids:
if payload.knowledge_base_ids and model_profile is not None:
try:
knowledge_bases = (
db.query(KnowledgeBase)
@@ -61,7 +66,7 @@ async def generate_testing_content(
.all()
)
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
if kb_vector_stores:
retriever = MultiKBRetriever(
reranker_weight=settings.RERANKER_WEIGHT,
@@ -81,13 +86,39 @@ async def generate_testing_content(
payload.knowledge_base_ids,
exc,
)
elif payload.knowledge_base_ids:
logger.warning(
"Testing generation skipped knowledge retrieval because no active model config exists: user=%s knowledge_base_ids=%s",
current_user.id,
payload.knowledge_base_ids,
)
use_model_generation = bool(payload.use_model_generation and model_profile is not None)
llm_model = None
if use_model_generation:
try:
llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
ModelConfigService.touch_last_used(db, model_profile)
except Exception as exc:
logger.exception(
"Testing generation LLM initialization failed for user=%s, falling back to rule-based output: %s",
current_user.id,
exc,
)
use_model_generation = False
elif payload.use_model_generation:
logger.warning(
"Testing generation falling back to rule-based output because no active model config exists: user=%s",
current_user.id,
)
pipeline_kwargs = {
"user_requirement_text": payload.requirement_text,
"requirement_type_input": payload.requirement_type,
"debug": payload.debug,
"knowledge_context": knowledge_context,
"use_model_generation": payload.use_model_generation,
"use_model_generation": use_model_generation,
"llm_model": llm_model,
"max_items_per_group": payload.max_items_per_group,
"cases_per_item": payload.cases_per_item,
"max_focus_points": payload.max_focus_points,

View File

@@ -34,6 +34,7 @@ from app.services.srs_job_service import (
replace_requirements,
run_srs_job,
)
from app.services.model_config import ModelConfigService
from app.services.testing_generation_service import (
build_testing_generation_response,
create_testing_generation,
@@ -72,6 +73,11 @@ async def create_srs_job(
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")
try:
ModelConfigService.require_active_config(db, current_user.id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
job = ToolJob(
user_id=current_user.id,
tool_name="srs.requirement_extractor",

View File

@@ -9,6 +9,7 @@ from app.db.session import get_db
from app.core.security import get_api_key_user
from app.core.config import settings
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.model_config import ModelConfigService
router = APIRouter()
@@ -36,7 +37,11 @@ def query_knowledge_base(
detail=f"Knowledge base {knowledge_base_id} not found",
)
embeddings = EmbeddingsFactory.create()
try:
model_profile = ModelConfigService.require_active_config(db, current_user.id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
@@ -56,5 +61,7 @@ def query_knowledge_base(
return {"results": response}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,4 +1,5 @@
import os
import secrets
from typing import List, Optional
from pydantic_settings import BaseSettings
@@ -8,6 +9,7 @@ class Settings(BaseSettings):
PROJECT_NAME: str = "RAG Web UI" # Project name
VERSION: str = "0.1.0" # Project version
API_V1_STR: str = "/api" # API version string
ENVIRONMENT: str = os.getenv("ENVIRONMENT", os.getenv("APP_ENV", "development"))
# MySQL settings
MYSQL_SERVER: str = os.getenv("MYSQL_SERVER", "localhost")
@@ -27,7 +29,7 @@ class Settings(BaseSettings):
)
# JWT settings
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
SECRET_KEY: str = os.getenv("SECRET_KEY", "")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "10080"))
@@ -48,9 +50,7 @@ class Settings(BaseSettings):
# OpenAI settings
OPENAI_API_BASE: str = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
OPENAI_API_KEY: str = os.getenv(
"OPENAI_API_KEY", os.getenv("API_KEY", "your-openai-api-key-here")
)
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", os.getenv("API_KEY", ""))
OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "gpt-4")
OPENAI_EMBEDDINGS_MODEL: str = os.getenv("OPENAI_EMBEDDINGS_MODEL", "text-embedding-ada-002")
@@ -63,7 +63,7 @@ class Settings(BaseSettings):
"DASH_SCOPE_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1"
)
DASH_SCOPE_CHAT_MODEL: str = os.getenv("DASH_SCOPE_CHAT_MODEL", "qwen3-max")
DASH_SCOPE_EMBEDDINGS_MODEL: str = os.getenv("DASH_SCOPE_EMBEDDINGS_MODEL", "")
DASH_SCOPE_EMBEDDINGS_MODEL: str = os.getenv("DASH_SCOPE_EMBEDDINGS_MODEL", "text-embedding-v4")
# Vector Store settings
VECTOR_STORE_TYPE: str = os.getenv("VECTOR_STORE_TYPE", "chroma")
@@ -121,3 +121,7 @@ class Settings(BaseSettings):
settings = Settings()
if not settings.SECRET_KEY:
if settings.ENVIRONMENT.lower() in {"prod", "production"}:
raise ValueError("SECRET_KEY must be set in production.")
settings.SECRET_KEY = secrets.token_urlsafe(32)

View File

@@ -2,7 +2,16 @@ from .user import User
from .knowledge import KnowledgeBase, Document, DocumentChunk
from .chat import Chat, Message
from .api_key import APIKey
from .tooling import ToolJob, SRSExtraction, SRSRequirement, TestingGeneration
from .model_config import UserModelConfig
from .tooling import (
CodeKnowledgeBase,
ConsistencyJob,
ConsistencyResult,
SRSExtraction,
SRSRequirement,
TestingGeneration,
ToolJob,
)
__all__ = [
"User",
@@ -12,8 +21,12 @@ __all__ = [
"Chat",
"Message",
"APIKey",
"UserModelConfig",
"ToolJob",
"SRSExtraction",
"SRSRequirement",
"TestingGeneration",
"CodeKnowledgeBase",
"ConsistencyJob",
"ConsistencyResult",
]

View File

@@ -0,0 +1,34 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin
class UserModelConfig(Base, TimestampMixin):
    """Per-user model provider configuration stored in ``user_model_configs``."""

    __tablename__ = "user_model_configs"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    name = Column(String(255), nullable=False)
    provider = Column(String(64), nullable=False, default="dashscope")
    # The key is stored as given in a Text column.
    api_key = Column(Text, nullable=False)
    api_base = Column(String(512), nullable=True)
    chat_model = Column(String(128), nullable=False)
    embedding_model = Column(String(128), nullable=False)
    is_active = Column(Boolean, default=True, nullable=False, index=True)
    last_used_at = Column(DateTime(timezone=True), nullable=True)

    user = relationship("User", back_populates="model_configs")

    @property
    def has_api_key(self) -> bool:
        """Whether the stored key contains any non-whitespace characters."""
        return len((self.api_key or "").strip()) > 0

    @property
    def api_key_masked(self) -> str:
        """Return a display-safe form of the key.

        Keys longer than 10 characters keep their first and last four
        characters around eight asterisks. Shorter keys are replaced by
        asterisks of the same length, so an empty key yields "".
        """
        secret = (self.api_key or "").strip()
        if len(secret) > 10:
            return secret[:4] + "*" * 8 + secret[-4:]
        return "*" * len(secret)

View File

@@ -1,7 +1,7 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import Column, DateTime, ForeignKey, Integer, JSON, String, Text
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, JSON, String, Text
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import relationship
@@ -102,3 +102,70 @@ class TestingGeneration(Base, TimestampMixin):
job = relationship("ToolJob", back_populates="testing_generation", foreign_keys=[job_id])
# ORM model for a user's code knowledge base (table "code_knowledge_bases").
# A row records where the knowledge base artifacts live (vector, metadata and
# graph paths) plus an optional path to the source project.
class CodeKnowledgeBase(Base, TimestampMixin):
__tablename__ = "code_knowledge_bases"
id = Column(Integer, primary_key=True, index=True)
# Owning user; required and indexed.
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
name = Column(String(255), nullable=False)
# The only path column that may be NULL.
project_path = Column(String(1024), nullable=True)
vector_path = Column(String(1024), nullable=False)
metadata_path = Column(String(1024), nullable=False)
graph_path = Column(String(1024), nullable=False)
# Free-form status string; new rows default to "active".
status = Column(String(32), nullable=False, default="active", index=True)
metadata_summary = Column(JSON, nullable=True)
user = relationship("User")
# ORM-level cascade: deleting a knowledge base through the session also
# deletes the ConsistencyJob rows attached to it.
consistency_jobs = relationship(
"ConsistencyJob",
back_populates="code_kb",
cascade="all, delete-orphan",
)
# ORM model for one consistency-check run (table "consistency_jobs").
# Each job links a user, an SRS extraction and a code knowledge base, and
# tracks progress counters, a JSON summary and an optional error message.
class ConsistencyJob(Base, TimestampMixin):
__tablename__ = "consistency_jobs"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
srs_extraction_id = Column(Integer, ForeignKey("srs_extractions.id"), nullable=False, index=True)
code_kb_id = Column(Integer, ForeignKey("code_knowledge_bases.id"), nullable=False, index=True)
# Free-form status string; new rows default to "pending".
status = Column(String(32), nullable=False, default="pending", index=True)
# Progress counters, both starting at 0.
total_requirements = Column(Integer, nullable=False, default=0)
completed_requirements = Column(Integer, nullable=False, default=0)
output_summary = Column(JSON, nullable=True)
error_message = Column(Text, nullable=True)
# Declared as plain DateTime (no timezone=True).
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
user = relationship("User")
srs_extraction = relationship("SRSExtraction")
code_kb = relationship("CodeKnowledgeBase", back_populates="consistency_jobs")
# Results are loaded ordered by their id and are deleted together with the
# job through the ORM cascade.
results = relationship(
"ConsistencyResult",
back_populates="job",
cascade="all, delete-orphan",
order_by="ConsistencyResult.id",
)
# ORM model for the verdict on a single requirement within a consistency job
# (table "consistency_results").
class ConsistencyResult(Base, TimestampMixin):
__tablename__ = "consistency_results"
id = Column(Integer, primary_key=True, index=True)
# Parent job; the foreign key is declared with ON DELETE CASCADE.
job_id = Column(Integer, ForeignKey("consistency_jobs.id", ondelete="CASCADE"), nullable=False, index=True)
requirement_uid = Column(String(64), nullable=False, index=True)
verdict = Column(String(32), nullable=False, index=True)
# Both scores are required floats defaulting to 0.0.
coverage_score = Column(Float, nullable=False, default=0.0)
confidence = Column(Float, nullable=False, default=0.0)
# The JSON evidence columns below are all NOT NULL and have no default,
# so a value must be supplied for each when a row is inserted.
matched_functions = Column(JSON, nullable=False)
covered_points = Column(JSON, nullable=False)
missing_points = Column(JSON, nullable=False)
conflict_points = Column(JSON, nullable=False)
call_chain_evidence = Column(JSON, nullable=False)
suggestion = Column(Text, nullable=True)
raw_judgment = Column(JSON, nullable=True)
job = relationship("ConsistencyJob", back_populates="results")

View File

@@ -15,4 +15,5 @@ class User(Base, TimestampMixin):
# Relationships
knowledge_bases = relationship("KnowledgeBase", back_populates="user")
chats = relationship("Chat", back_populates="user")
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
model_configs = relationship("UserModelConfig", back_populates="user", cascade="all, delete-orphan")

View File

@@ -1,4 +1,10 @@
from .api_key import APIKey, APIKeyCreate, APIKeyUpdate, APIKeyInDB
from .model_config import (
ModelConfigCreate,
ModelConfigResponse,
ModelConfigUpdate,
ModelProviderOptionsResponse,
)
from .user import UserBase, UserCreate, UserUpdate, UserResponse
from .token import Token, TokenPayload
from .knowledge import KnowledgeBaseBase, KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseResponse

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Fields for registering a code knowledge base. The name and the three
# artifact paths must be non-empty strings; project_path is optional.
class CodeKnowledgeBaseCreate(BaseModel):
    name: str = Field(..., min_length=1)
    project_path: Optional[str] = None
    vector_path: str = Field(..., min_length=1)
    metadata_path: str = Field(..., min_length=1)
    graph_path: str = Field(..., min_length=1)
# Minimal id/status pair for a code knowledge base.
class CodeKnowledgeBaseUploadResponse(BaseModel):
    id: int
    status: str
# Full view of a code knowledge base record. `from_attributes` lets the model
# be populated from an object's attributes instead of a dict.
class CodeKnowledgeBaseResponse(BaseModel):
    id: int
    name: str
    project_path: Optional[str] = None

    vector_path: str
    metadata_path: str
    graph_path: str

    status: str
    metadata_summary: Optional[Dict[str, Any]] = None

    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
# A question against a code knowledge base.
# top_k is limited to 1..20 and min_similarity to 0.0..1.0.
class CodeQuestionRequest(BaseModel):
    question: str = Field(..., min_length=1)
    top_k: int = Field(default=6, ge=1, le=20)
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0)
    use_llm: bool = True
# Answer text plus the evidence entries it is based on; raw_response is optional.
class CodeQuestionResponse(BaseModel):
    answer: str
    evidence: List[Dict[str, Any]]
    raw_response: Optional[str] = None
# Parameters for a consistency job between an SRS extraction and a code KB.
# Bounds: top_k 1..30, max_call_hops 0..4, min_similarity 0.0..1.0.
class ConsistencyJobCreate(BaseModel):
    srs_extraction_id: int
    code_kb_id: int
    requirement_uids: Optional[List[str]] = None

    top_k: int = Field(default=8, ge=1, le=30)
    max_call_hops: int = Field(default=2, ge=0, le=4)
    min_similarity: float = Field(default=0.55, ge=0.0, le=1.0)
    use_llm: bool = True
# Job id and status pair for a consistency job.
class ConsistencyJobCreateResponse(BaseModel):
    job_id: int
    status: str
# Job id and status pair for an automatic consistency job.
class AutoConsistencyJobCreateResponse(BaseModel):
    job_id: int
    status: str
# Status view of an automatic consistency job. The ids of related records and
# the progress/error fields are all optional.
class AutoConsistencyJobStatusResponse(BaseModel):
    job_id: int
    status: str
    error_message: Optional[str] = None
    current_step: Optional[str] = None

    srs_extraction_id: Optional[int] = None
    code_kb_id: Optional[int] = None
    consistency_job_id: Optional[int] = None

    created_at: datetime
    updated_at: datetime
# Full view of a consistency job, including progress counters and timestamps.
# `from_attributes` lets the model be populated from an object's attributes.
class ConsistencyJobResponse(BaseModel):
    id: int
    srs_extraction_id: int
    code_kb_id: int
    status: str

    total_requirements: int
    completed_requirements: int

    output_summary: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
# View of one stored consistency result.
# `from_attributes` lets the model be populated from an object's attributes.
class ConsistencyResultResponse(BaseModel):
    id: int
    job_id: int
    requirement_uid: str

    verdict: str
    coverage_score: float
    confidence: float

    matched_functions: List[Dict[str, Any]]
    covered_points: List[str]
    missing_points: List[str]
    conflict_points: List[str]
    call_chain_evidence: List[str]

    suggestion: Optional[str] = None
    raw_judgment: Optional[Dict[str, Any]] = None

    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True

View File

@@ -0,0 +1,57 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Fields shared by the model-config schemas below; only `name` has no default.
class ModelConfigBase(BaseModel):
    name: str
    provider: str = "dashscope"
    api_base: Optional[str] = None
    chat_model: str = "qwen3-max"
    embedding_model: str = "text-embedding-v4"
    is_active: bool = True
# Base fields plus an API key, which defaults to the empty string.
class ModelConfigCreate(ModelConfigBase):
    api_key: str = Field(default="")
# Partial update: every field is optional and defaults to None.
class ModelConfigUpdate(BaseModel):
    name: Optional[str] = None
    provider: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    chat_model: Optional[str] = None
    embedding_model: Optional[str] = None
    is_active: Optional[bool] = None
# Base fields plus record ids and timestamps. This schema has no `api_key`
# field; it carries `has_api_key` and `api_key_masked` instead.
class ModelConfigResponse(ModelConfigBase):
    id: int
    user_id: int

    has_api_key: bool
    api_key_masked: str

    last_used_at: Optional[datetime] = None
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
# One selectable provider: its label, default endpoint/models, the model
# lists offered for it, and two capability flags (both default to True).
class ModelProviderOption(BaseModel):
    provider: str
    label: str

    default_api_base: str
    default_chat_model: str
    default_embedding_model: str

    chat_models: List[str]
    embedding_models: List[str]

    requires_api_key: bool = True
    supports_custom_api_base: bool = True
# List of provider options together with a free-form defaults mapping.
class ModelProviderOptionsResponse(BaseModel):
    providers: List[ModelProviderOption]
    defaults: Dict[str, Any]

View File

@@ -17,6 +17,7 @@ from app.services.fusion_prompts import (
from app.services.graph.graphrag_adapter import GraphRAGAdapter
from app.services.intent_router import route_intent
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.reranker.external_api import ExternalRerankerClient
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline.pipeline import run_testing_pipeline
@@ -202,8 +203,12 @@ def _build_global_community_context_fallback(rows: List[Dict[str, Any]]) -> str:
return "\n\n".join(lines)
async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
async def _build_kb_vector_stores(
db: Any,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -221,10 +226,13 @@ async def _build_kb_vector_stores(db: Any, knowledge_bases: List[KnowledgeBase])
return kb_vector_stores
def _build_reranker_client() -> ExternalRerankerClient:
def _build_reranker_client(model_profile: Any = None) -> ExternalRerankerClient:
api_key = settings.RERANKER_API_KEY
if model_profile is not None and getattr(model_profile, "provider", "") == "dashscope":
api_key = getattr(model_profile, "api_key", "") or api_key
return ExternalRerankerClient(
api_url=settings.RERANKER_API_URL,
api_key=settings.RERANKER_API_KEY,
api_key=api_key,
model=settings.RERANKER_MODEL,
timeout_seconds=settings.RERANKER_TIMEOUT_SECONDS,
)
@@ -287,6 +295,7 @@ async def generate_response(
knowledge_base_ids: List[int],
chat_id: int,
db: Any,
user_id: int,
) -> AsyncGenerator[str, None]:
try:
user_message = Message(content=query, role="user", chat_id=chat_id)
@@ -297,6 +306,9 @@ async def generate_response(
db.add(bot_message)
db.commit()
model_profile = ModelConfigService.require_active_config(db, user_id)
ModelConfigService.touch_last_used(db, model_profile)
if _is_testing_generation_request(query):
explicit_type = _extract_requirement_type_from_query(query)
@@ -309,7 +321,7 @@ async def generate_response(
.filter(KnowledgeBase.id.in_(knowledge_base_ids))
.all()
)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs)
kb_vector_stores = await _build_kb_vector_stores(db, testing_kbs, model_profile)
if kb_vector_stores:
testing_retriever = MultiKBRetriever(
@@ -330,6 +342,7 @@ async def generate_response(
debug=True,
knowledge_context=knowledge_context,
use_model_generation=True,
llm_model=LLMFactory.create(streaming=False, model_profile=model_profile),
max_items_per_group=6,
cases_per_item=1,
max_focus_points=6,
@@ -391,11 +404,11 @@ async def generate_response(
)
kb_ids = [kb.id for kb in knowledge_bases]
llm = LLMFactory.create()
llm = LLMFactory.create(model_profile=model_profile)
decision = await route_intent(llm=llm, query=query, messages=messages)
intent = decision["intent"]
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
if intent in {"B", "C", "D"} and not kb_vector_stores:
intent = "A"
decision = {
@@ -403,7 +416,7 @@ async def generate_response(
"reason": "未发现可用知识库向量集合,已降级为通用对话路。",
}
reranker_client = _build_reranker_client()
reranker_client = _build_reranker_client(model_profile)
retriever = MultiKBRetriever(
reranker_client=reranker_client,
reranker_weight=settings.RERANKER_WEIGHT,
@@ -432,7 +445,7 @@ async def generate_response(
used_kb_ids: List[int] = []
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
graph_context, used_kb_ids = await adapter.local_context_multi(
kb_ids,
query,
@@ -465,7 +478,7 @@ async def generate_response(
community_context = ""
if settings.GRAPHRAG_ENABLED and kb_ids:
try:
adapter = GraphRAGAdapter()
adapter = GraphRAGAdapter(model_profile=model_profile)
community_context, used_kb_ids = await adapter.global_context_multi(
kb_ids,
query,

View File

@@ -0,0 +1,10 @@
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
__all__ = [
"CodeFunctionEvidence",
"CodeGraphContext",
"CodeKnowledgeBaseAdapter",
"CodeSearchHit",
]

View File

@@ -0,0 +1,517 @@
from __future__ import annotations
import json
import logging
import math
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from app.services.code_kb.graph import CodeCallGraph
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext, CodeSearchHit
logger = logging.getLogger(__name__)
# Canonical function field -> metadata keys that may carry it.
# The keys of each tuple are tried in order by `_first_value`, which returns
# the first one whose value is neither None nor an empty string.
FIELD_ALIASES = {
"name": ("name", "function_name"),
"file": ("file", "file_path"),
"summary": ("summary",),
"logic_flow": ("logic_flow", "logic"),
"code_snippet": ("code_snippet", "source", "code"),
"calls": ("calls", "called_functions"),
"called_by": ("called_by", "caller_functions", "callers"),
}
class SimpleVectorIndex:
    """In-memory squared-L2 vector index backed by a JSON file.

    The class exposes ``d``, ``ntotal`` and ``search(vector, top_k)``, which
    are the members the code-KB adapter reads from its vector index.  numpy is
    used when it can be imported; otherwise a pure-Python scan is performed.
    """

    def __init__(self, vectors: Any, dimension: int) -> None:
        """Store ``vectors`` as lists of floats.

        Args:
            vectors: Iterable of numeric sequences (``None`` means empty).
            dimension: Declared dimension; when falsy, the length of the first
                stored vector is used (0 for an empty index).
        """
        try:
            import numpy as np
        except ImportError:  # pragma: no cover - depends on deployment environment
            np = None
        self._np = np
        self.vectors = [[float(value) for value in vector] for vector in (vectors or [])]
        self.d = int(dimension or (len(self.vectors[0]) if self.vectors else 0))
        self.ntotal = len(self.vectors)
        # float32 matrix of ``vectors`` for the numpy path; built on first
        # search so that construction never fails on malformed rows.
        self._matrix: Any = None

    @classmethod
    def from_file(cls, vector_path: str) -> Optional["SimpleVectorIndex"]:
        """Load an index from ``vector_path``.

        Returns:
            The index, or ``None`` when the file cannot be read or decoded as
            UTF-8 JSON, or is not a JSON object whose ``"format"`` is
            ``"simple_l2_vector_index"``.
        """
        try:
            with open(vector_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            return None
        if not isinstance(payload, dict) or payload.get("format") != "simple_l2_vector_index":
            return None
        return cls(payload.get("vectors") or [], int(payload.get("dimension") or 0))

    def search(self, vector: Any, top_k: int) -> Any:
        """Return the ``top_k`` nearest stored vectors by squared L2 distance.

        Args:
            vector: Either one query row nested in an outer sequence
                (``[[...]]``) or a flat query row (``[...]``).  Both forms are
                accepted whether or not numpy is available.
            top_k: Maximum number of results; values below zero are treated
                as zero.

        Returns:
            ``(distances, indices)``, each holding a single row.  numpy arrays
            (float32 / int64) when numpy is available, nested lists otherwise.

        Raises:
            ValueError: If the query length differs from the index dimension.
        """
        np = self._np
        limit = max(0, top_k)

        if self.ntotal == 0:
            if np is not None:
                return np.array([[]], dtype="float32"), np.array([[]], dtype="int64")
            return [[]], [[]]

        if np is not None:
            query = np.array(vector, dtype="float32")
            if query.ndim == 1:
                # Accept a flat query row, as the pure-Python path does.
                query = query.reshape(1, -1)
            if query.ndim != 2 or query.shape[1] != self.d:
                raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
            if self._matrix is None:
                self._matrix = np.array(self.vectors, dtype="float32")
            distances = np.sum((self._matrix - query[0]) ** 2, axis=1)
            order = np.argsort(distances)[:limit]
            return np.array([distances[order]], dtype="float32"), np.array([order], dtype="int64")

        nested = (
            isinstance(vector, (list, tuple))
            and len(vector) > 0
            and isinstance(vector[0], (list, tuple))
        )
        query_row = vector[0] if nested else vector
        query_values = [float(value) for value in query_row]
        if len(query_values) != self.d:
            raise ValueError(f"Query embedding dimension does not match index dimension {self.d}.")
        scored = [
            (sum((left - right) ** 2 for left, right in zip(stored, query_values)), index)
            for index, stored in enumerate(self.vectors)
        ]
        scored.sort(key=lambda item: item[0])
        selected = scored[:limit]
        return [[item[0] for item in selected]], [[item[1] for item in selected]]

    def has_nonzero_vectors(self) -> bool:
        """Return True if any stored component has magnitude above 1e-12."""
        return any(
            any(abs(float(value)) > 1e-12 for value in vector)
            for vector in self.vectors
        )
class UnavailableVectorIndex:
    """Placeholder index that reports zero size and refuses every search."""

    # Same size attributes a real index exposes, fixed at zero.
    d = 0
    ntotal = 0

    def __init__(self, reason: str) -> None:
        """Remember why no usable index is available."""
        self.reason = reason

    def search(self, vector: Any, top_k: int) -> Any:
        """Always raise ``RuntimeError`` carrying the stored reason."""
        message = self.reason
        raise RuntimeError(message)
def _first_value(source: Dict[str, Any], aliases: Iterable[str], default: Any = None) -> Any:
for alias in aliases:
value = source.get(alias)
if value not in (None, ""):
return value
return default
def _as_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, list):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, tuple):
return [str(item) for item in value if item not in (None, "")]
if isinstance(value, str):
return [value] if value else []
return [str(value)]
def _as_int(value: Any) -> Optional[int]:
if value in (None, ""):
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _as_bool(value: Any, default: bool = True) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "y"}:
return True
if normalized in {"0", "false", "no", "n"}:
return False
return default
def _load_json(path: str) -> Any:
with open(path, "r", encoding="utf-8") as file:
return json.load(file)
def read_code_kb_summary(metadata_path: str, graph_path: str) -> Dict[str, Any]:
    """Summarize a code knowledge base directly from its JSON files.

    Args:
        metadata_path: Function metadata JSON.  Either a list of rows or a
            dict holding the rows under ``"functions"`` or ``"items"`` -- the
            same layouts ``CodeKnowledgeBaseAdapter._read_metadata`` accepts.
        graph_path: Graph JSON.  It is only read when the path is non-empty
            and exists; otherwise every graph-derived value is None.

    Returns:
        A dict with ``function_count`` (number of dict rows in the metadata)
        and ``graph_nodes``, ``graph_edges``, ``project_root`` and
        ``generated_at`` copied from the graph file's ``"metadata"`` object.
    """
    metadata = _load_json(metadata_path)
    graph_data = _load_json(graph_path) if graph_path and Path(graph_path).exists() else {}

    if isinstance(metadata, dict):
        # Mirror the adapter: rows may live under "functions" or "items".
        metadata_items = metadata.get("functions", metadata.get("items", []))
    else:
        metadata_items = metadata
    if not isinstance(metadata_items, list):
        metadata_items = []

    graph_meta = graph_data.get("metadata", {}) if isinstance(graph_data, dict) else {}
    if not isinstance(graph_meta, dict):
        # A malformed "metadata" entry must not break the summary.
        graph_meta = {}

    return {
        # Only dict rows are loaded as functions by the adapter, so only
        # those are counted here.
        "function_count": sum(1 for item in metadata_items if isinstance(item, dict)),
        "graph_nodes": graph_meta.get("total_nodes"),
        "graph_edges": graph_meta.get("total_edges"),
        "project_root": graph_meta.get("project_root"),
        "generated_at": graph_meta.get("generated_at"),
    }
class CodeKnowledgeBaseAdapter:
    """Load a code knowledge base from disk and search its functions.

    ``load`` reads a vector index, a function metadata JSON file and a graph
    JSON file, and normalizes every metadata row into a
    ``CodeFunctionEvidence``.  ``search_functions`` queries the vector index
    when ``_configure_vector_search`` enabled it, and otherwise (or when the
    vector query fails or yields nothing) uses a token-based lexical search.
    """

    # Compiled once; ``_tokens`` runs for every function on each lexical search.
    # Matches identifier-like words, numbers, and runs of two or more CJK ideographs.
    _TOKEN_PATTERN = re.compile(r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|[\u4e00-\u9fff]{2,}")
    # CJK runs of three or more characters are additionally split into bigrams.
    _CJK_RUN_PATTERN = re.compile(r"[\u4e00-\u9fff]{3,}")

    def __init__(self, embedding_function: Any = None) -> None:
        """Create an unloaded adapter.

        Args:
            embedding_function: Object with ``embed_query(text)`` or a plain
                callable.  When None, one is created through
                ``EmbeddingsFactory`` on first use.
        """
        self.embedding_function = embedding_function
        self.faiss_index: Any = None
        self.metadata: List[Dict[str, Any]] = []
        self.graph_data: Dict[str, Any] = {}
        self.functions: List[CodeFunctionEvidence] = []
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {}
        self.call_graph: Optional[CodeCallGraph] = None
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = "Code knowledge base has not been loaded."

    def load(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Read the three artifacts and build the in-memory structures.

        Raises:
            FileNotFoundError: If any of the three paths is empty or missing.
            ValueError: If the metadata file has an unsupported layout.
        """
        self._validate_paths(vector_path, metadata_path, graph_path)
        self.faiss_index = self._read_faiss_index(vector_path)
        self.metadata = self._read_metadata(metadata_path)
        self.graph_data = _load_json(graph_path)
        graph_nodes = self._index_graph_nodes(self.graph_data)
        self.functions = [
            self._normalize_metadata(row, graph_nodes, index_dimension=self.faiss_index.d)
            for row in self.metadata
        ]
        self.functions_by_id = {item.node_id: item for item in self.functions}
        self.call_graph = CodeCallGraph(self.graph_data, self.functions)
        self._configure_vector_search()

    def _validate_paths(self, vector_path: str, metadata_path: str, graph_path: str) -> None:
        """Raise ``FileNotFoundError`` listing every empty or non-existent path."""
        missing = [
            path
            for path in [vector_path, metadata_path, graph_path]
            if not path or not Path(path).exists()
        ]
        if missing:
            raise FileNotFoundError(f"Code knowledge base files not found: {missing}")

    def _read_faiss_index(self, vector_path: str) -> Any:
        """Open the vector index at ``vector_path``.

        The JSON ``SimpleVectorIndex`` format is tried first.  Otherwise the
        file is read with faiss; if faiss cannot be imported, an
        ``UnavailableVectorIndex`` is returned so that loading still succeeds.
        """
        simple_index = SimpleVectorIndex.from_file(vector_path)
        if simple_index is not None:
            return simple_index
        try:
            import faiss  # type: ignore
        except ImportError:
            logger.warning("faiss is not installed; code search will use lexical fallback.")
            return UnavailableVectorIndex(
                "faiss is required to read this vector index. Install faiss-cpu or rebuild the code KB."
            )
        return faiss.read_index(vector_path)

    def _read_metadata(self, metadata_path: str) -> List[Dict[str, Any]]:
        """Return the dict rows of the metadata file.

        A top-level dict is unwrapped via its ``"functions"`` key, then its
        ``"items"`` key.  Non-dict rows are dropped.

        Raises:
            ValueError: If the (unwrapped) payload is not a list.
        """
        metadata = _load_json(metadata_path)
        if isinstance(metadata, dict):
            metadata = metadata.get("functions", metadata.get("items", []))
        if not isinstance(metadata, list):
            raise ValueError("Code KB metadata must be a list or a dict with functions/items.")
        return [item for item in metadata if isinstance(item, dict)]

    def _index_graph_nodes(self, graph_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map node id to node dict for every graph node that has a truthy ``"id"``."""
        return {
            item.get("id"): item
            for item in graph_data.get("nodes", []) or []
            if isinstance(item, dict) and item.get("id")
        }

    def _normalize_metadata(
        self,
        metadata: Dict[str, Any],
        graph_nodes: Dict[str, Dict[str, Any]],
        index_dimension: int,
    ) -> CodeFunctionEvidence:
        """Build a ``CodeFunctionEvidence`` from one metadata row.

        Values are taken from the row first (trying the keys listed in
        ``FIELD_ALIASES``) and fall back to the matching graph node or its
        ``raw_attributes``.  When the row has no ``node_id``, the id is
        derived as ``"Function:<name>"``.
        """
        name = _first_value(metadata, FIELD_ALIASES["name"], "")
        node_id = metadata.get("node_id") or (f"Function:{name}" if name else "")
        graph_node = graph_nodes.get(node_id, {})
        raw_attributes = graph_node.get("raw_attributes") or {}
        if not name:
            name = graph_node.get("name") or node_id.removeprefix("Function:")
        file_path = _first_value(metadata, FIELD_ALIASES["file"], graph_node.get("file_path", ""))
        embedding_dim = metadata.get("embedding_dim") or index_dimension
        embedding_available = _as_bool(
            metadata.get("embedding_available", raw_attributes.get("embedding_available")),
            default=True,
        )
        return CodeFunctionEvidence(
            node_id=node_id,
            name=name,
            qualified_name=metadata.get("qualified_name") or node_id.removeprefix("Function:") or name,
            file=file_path or "",
            start_line=_as_int(metadata.get("start_line") or graph_node.get("start_line")),
            end_line=_as_int(metadata.get("end_line") or graph_node.get("end_line")),
            signature=metadata.get("signature") or graph_node.get("signature") or "",
            summary=_first_value(metadata, FIELD_ALIASES["summary"], graph_node.get("summary", "")) or "",
            logic_flow=_first_value(
                metadata, FIELD_ALIASES["logic_flow"], graph_node.get("logic_flow", "")
            )
            or "",
            code_snippet=_first_value(
                metadata, FIELD_ALIASES["code_snippet"], raw_attributes.get("code_snippet", "")
            )
            or "",
            calls=_as_list(_first_value(metadata, FIELD_ALIASES["calls"], raw_attributes.get("calls", []))),
            called_by=_as_list(
                _first_value(metadata, FIELD_ALIASES["called_by"], raw_attributes.get("called_by", []))
            ),
            includes=_as_list(metadata.get("includes") or raw_attributes.get("includes")),
            embedding_model=metadata.get("embedding_model") or "",
            embedding_dim=int(embedding_dim or 0),
            embedding_available=embedding_available,
            raw=metadata,
        )

    def _configure_vector_search(self) -> None:
        """Decide whether vector search is usable and record why not.

        Vector search stays disabled when the index is unavailable or empty,
        when its size differs from the number of functions (search results
        are used as positions into ``self.functions``), when it holds only
        zero vectors, or when no function is flagged as having an embedding.
        """
        index_total = int(getattr(self.faiss_index, "ntotal", 0))
        self.vector_search_enabled = False
        self.vector_search_disabled_reason = ""
        if isinstance(self.faiss_index, UnavailableVectorIndex):
            self.vector_search_disabled_reason = self.faiss_index.reason
            return
        if index_total == 0:
            self.vector_search_disabled_reason = "Vector index is empty."
            return
        if index_total != len(self.functions):
            self.vector_search_disabled_reason = (
                f"Vector index size ({index_total}) differs from metadata size ({len(self.functions)})."
            )
            logger.warning(
                "FAISS index size (%s) differs from metadata size (%s)",
                index_total,
                len(self.functions),
            )
            return
        if not self._index_has_nonzero_vectors():
            self.vector_search_disabled_reason = "Vector index contains only zero embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        if not any(item.embedding_available for item in self.functions):
            self.vector_search_disabled_reason = "No function metadata has usable embeddings."
            logger.info("Code KB vector search disabled: %s", self.vector_search_disabled_reason)
            return
        self.vector_search_enabled = True

    def _index_has_nonzero_vectors(self) -> bool:
        """Return False only when the index is known to hold zero vectors.

        Uses the index's own ``has_nonzero_vectors`` when present; otherwise
        inspects up to the first 1000 vectors through ``reconstruct_n``.  If
        the index cannot be inspected, True is returned (best effort).
        """
        has_nonzero_vectors = getattr(self.faiss_index, "has_nonzero_vectors", None)
        if callable(has_nonzero_vectors):
            return bool(has_nonzero_vectors())
        reconstruct_n = getattr(self.faiss_index, "reconstruct_n", None)
        if not callable(reconstruct_n):
            return True
        try:
            import numpy as np

            index_total = int(getattr(self.faiss_index, "ntotal", 0) or 0)
            sample_size = min(index_total, 1000)
            if sample_size <= 0:
                return False
            vectors = reconstruct_n(0, sample_size)
            return bool(np.any(np.abs(vectors) > 1e-12))
        except Exception as exc:  # pragma: no cover - depends on FAISS capabilities
            logger.debug("Could not inspect FAISS vectors for zero embeddings: %s", exc)
            return True

    def _get_embedding_function(self) -> Any:
        """Return the embedding function, creating and caching a default one if needed."""
        if self.embedding_function is None:
            from app.services.embedding.embedding_factory import EmbeddingsFactory

            self.embedding_function = EmbeddingsFactory.create()
        return self.embedding_function

    def _embed_query(self, query: str) -> Any:
        """Embed ``query`` and return it as a one-row float32 numpy array.

        Raises:
            RuntimeError: If numpy is not installed.
            TypeError: If the embedding function is neither callable nor has
                ``embed_query``.
            ValueError: If the embedding length differs from the index
                dimension.
        """
        # numpy is imported lazily here, so it is not a module-level name and
        # must not be referenced in the signature above.
        try:
            import numpy as np
        except ImportError as exc:
            raise RuntimeError("numpy is required to search a code knowledge base.") from exc
        embedding_function = self._get_embedding_function()
        if hasattr(embedding_function, "embed_query"):
            vector = embedding_function.embed_query(query)
        elif callable(embedding_function):
            vector = embedding_function(query)
        else:
            raise TypeError("Unsupported embedding function.")
        query_vector = np.array([vector], dtype="float32")
        index_dimension = int(getattr(self.faiss_index, "d", 0) or 0)
        if index_dimension and query_vector.shape[1] != index_dimension:
            raise ValueError(
                f"Query embedding dimension {query_vector.shape[1]} does not match "
                f"FAISS index dimension {index_dimension}."
            )
        return query_vector

    @staticmethod
    def distance_to_similarity(distance: float) -> float:
        """Map an index distance to a similarity in ``[0.0, 1.0]``.

        NaN and negative distances give 0.0.  Distances up to 2.0 map
        linearly (``1 - d / 2``); larger distances map to ``1 / (1 + d)``.

        NOTE(review): the two branches do not meet at 2.0 -- a distance of
        2.0 gives 0.0 while 3.0 gives 0.25 -- so the mapping is not monotonic.
        Confirm this is intended before relying on ``min_similarity`` cut-offs
        for distances above 2.0.
        """
        if math.isnan(distance) or distance < 0:
            return 0.0
        if distance <= 2.0:
            return max(0.0, min(1.0, 1.0 - distance / 2.0))
        return max(0.0, min(1.0, 1.0 / (1.0 + distance)))

    def search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Return up to ``top_k`` functions relevant to ``query``.

        Lexical search is used when vector search is disabled, when the
        vector query raises, or when no vector hit survives filtering.

        Raises:
            RuntimeError: If ``load`` has not been called.
        """
        if self.faiss_index is None:
            raise RuntimeError("Code knowledge base has not been loaded.")
        if not query.strip():
            return []
        if not self.vector_search_enabled:
            logger.info(
                "Vector code search disabled, using lexical search: %s",
                self.vector_search_disabled_reason,
            )
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        try:
            vector = self._embed_query(query)
            distances, indices = self.faiss_index.search(vector, top_k)
        except Exception as exc:
            # Deliberate best effort: any embedding/index failure degrades to lexical search.
            logger.warning("Vector code search failed, falling back to lexical search: %s", exc)
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        hits: List[CodeSearchHit] = []
        for rank, raw_index in enumerate(indices[0], start=1):
            index = int(raw_index)
            # Out-of-range positions (including negative ones) are skipped.
            if index < 0 or index >= len(self.functions):
                continue
            evidence = self.functions[index]
            if not evidence.embedding_available:
                continue
            distance = float(distances[0][rank - 1])
            similarity = self.distance_to_similarity(distance)
            if similarity < min_similarity:
                continue
            hits.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=similarity,
                    distance=distance,
                    rank=rank,
                )
            )
        if not hits:
            return self._lexical_search_functions(query, top_k=top_k, min_similarity=min_similarity)
        return hits

    def _lexical_search_functions(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Score every function against the query tokens and return the best ``top_k``.

        Hits are ordered by descending score and ranked from 1; ``distance``
        is reported as ``1 - score``.
        """
        query_tokens = self._tokens(query)
        if not query_tokens:
            return []
        scored: List[CodeSearchHit] = []
        for evidence in self.functions:
            text = self._function_search_text(evidence)
            score = self._lexical_similarity(query_tokens, text, evidence)
            if score < min_similarity:
                continue
            scored.append(
                CodeSearchHit(
                    evidence=evidence,
                    similarity=score,
                    distance=max(0.0, 1.0 - score),
                    rank=0,
                )
            )
        scored.sort(key=lambda item: item.similarity, reverse=True)
        selected = scored[:top_k]
        for rank, item in enumerate(selected, start=1):
            item.rank = rank
        return selected

    @staticmethod
    def _tokens(text: str) -> List[str]:
        """Split lower-cased ``text`` into search tokens.

        Each CJK run of three or more characters is followed by its
        overlapping two-character slices.
        """
        normalized = (text or "").lower()
        expanded: List[str] = []
        for token in CodeKnowledgeBaseAdapter._TOKEN_PATTERN.findall(normalized):
            expanded.append(token)
            if CodeKnowledgeBaseAdapter._CJK_RUN_PATTERN.fullmatch(token):
                expanded.extend(token[index : index + 2] for index in range(len(token) - 1))
        return expanded

    @staticmethod
    def _function_search_text(evidence: CodeFunctionEvidence) -> str:
        """Join the searchable fields of ``evidence`` into one lower-cased text.

        Only the first 4000 characters of the code snippet are included.
        """
        return "\n".join(
            [
                evidence.node_id,
                evidence.name,
                evidence.qualified_name,
                evidence.file,
                evidence.signature,
                evidence.summary,
                evidence.logic_flow,
                evidence.code_snippet[:4000],
                " ".join(evidence.calls),
                " ".join(evidence.called_by),
                " ".join(evidence.includes),
            ]
        ).lower()

    def _lexical_similarity(
        self,
        query_tokens: List[str],
        text: str,
        evidence: CodeFunctionEvidence,
    ) -> float:
        """Score ``evidence`` against the query tokens, clamped to ``[0.0, 1.0]``.

        The score is a weighted sum of token overlap (0.45), substring matches
        in ``text`` (0.35) and matches in the function name (0.12), plus small
        bonuses for functions that have a summary, logic flow, code snippet
        and a file with a start line.  Returns 0.0 when ``text`` has no tokens.
        """
        text_tokens = set(self._tokens(text))
        if not text_tokens:
            return 0.0
        unique_query_tokens = list(dict.fromkeys(query_tokens))
        overlap_count = sum(1 for token in unique_query_tokens if token in text_tokens)
        substring_count = sum(1 for token in unique_query_tokens if token in text)
        overlap_score = overlap_count / max(1, len(unique_query_tokens))
        substring_score = substring_count / max(1, len(unique_query_tokens))
        name_text = f"{evidence.name} {evidence.qualified_name}".lower()
        name_hits = sum(1 for token in unique_query_tokens if token in name_text)
        name_score = min(1.0, name_hits / max(1, min(len(unique_query_tokens), 4)))
        evidence_score = 0.0
        if evidence.summary:
            evidence_score += 0.08
        if evidence.logic_flow:
            evidence_score += 0.05
        if evidence.code_snippet:
            evidence_score += 0.04
        if evidence.start_line is not None and evidence.file:
            evidence_score += 0.03
        return max(
            0.0,
            min(
                1.0,
                overlap_score * 0.45
                + substring_score * 0.35
                + name_score * 0.12
                + evidence_score,
            ),
        )

    def get_function(self, node_id: str) -> Optional[CodeFunctionEvidence]:
        """Return the function with ``node_id``, or None if it is unknown."""
        return self.functions_by_id.get(node_id)

    def expand_call_context(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Return callers/callees of ``node_id``; an empty context before ``load``."""
        if not self.call_graph:
            return CodeGraphContext(node_id=node_id)
        return self.call_graph.expand(node_id, max_hops=max_hops)

    def summary(self) -> Dict[str, Any]:
        """Return counts and vector-search status for the loaded knowledge base."""
        return {
            "function_count": len(self.functions),
            "index_size": int(getattr(self.faiss_index, "ntotal", 0) or 0),
            "embedding_dim": int(getattr(self.faiss_index, "d", 0) or 0),
            "embedding_available_count": sum(1 for item in self.functions if item.embedding_available),
            "vector_search_enabled": self.vector_search_enabled,
            "vector_search_disabled_reason": self.vector_search_disabled_reason,
            "graph_nodes": len(self.graph_data.get("nodes", []) or []),
            "graph_edges": len(self.graph_data.get("edges", []) or []),
        }

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Iterable, List
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
def _clip(value: str, limit: int) -> str:
value = value or ""
if len(value) <= limit:
return value
return value[:limit].rstrip() + "\n...[truncated]"
def format_evidence_context(
    hits: Iterable[CodeSearchHit],
    graph_contexts: Iterable[CodeGraphContext],
    max_code_chars: int = 1000,
    max_text_chars: int = 800,
) -> str:
    """Render search hits as text blocks separated by ``---`` lines.

    Each block lists the function's identity, location, similarity,
    signature, summary, logic flow, calls and callers.  A call-chain line is
    added when a graph context with the same ``node_id`` is supplied, and the
    code snippet is appended when the function has one.

    Args:
        hits: Search hits to render, in output order.
        graph_contexts: Call-graph contexts, matched to hits by ``node_id``.
        max_code_chars: Character limit for the code snippet.
        max_text_chars: Character limit for summary and logic flow.
    """

    def shorten(text: str, limit: int) -> str:
        # Cut over-long text and mark the cut.
        text = text or ""
        if len(text) > limit:
            return text[:limit].rstrip() + "\n...[truncated]"
        return text

    graph_lookup = {}
    for context in graph_contexts:
        graph_lookup[context.node_id] = context

    rendered: List[str] = []
    for hit in hits:
        evidence = hit.evidence
        calls_text = ", ".join(evidence.calls[:20]) or "-"
        callers_text = ", ".join(evidence.called_by[:20]) or "-"
        section = [
            f"[Function Evidence #{hit.rank}]",
            f"node_id: {evidence.node_id}",
            f"name: {evidence.name}",
            f"qualified_name: {evidence.qualified_name}",
            f"file: {evidence.file}",
            f"lines: {evidence.start_line}-{evidence.end_line}",
            f"similarity: {hit.similarity:.4f}",
            f"signature: {shorten(evidence.signature, 300)}",
            f"summary: {shorten(evidence.summary, max_text_chars)}",
            f"logic_flow: {shorten(evidence.logic_flow, max_text_chars)}",
            f"calls: {calls_text}",
            f"called_by: {callers_text}",
        ]
        related = graph_lookup.get(evidence.node_id)
        if related:
            chains_text = "; ".join(related.call_chains[:12]) or "-"
            section.append(f"call_chain_evidence: {chains_text}")
        if evidence.code_snippet:
            section.append("code_snippet:")
            section.append(shorten(evidence.code_snippet, max_code_chars))
        rendered.append("\n".join(section))
    return "\n\n---\n\n".join(rendered)

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from collections import defaultdict, deque
from typing import Callable, Dict, Iterable, List, Optional, Set
from app.services.code_kb.schema import CodeFunctionEvidence, CodeGraphContext
class CodeCallGraph:
    """Caller/callee adjacency between known functions.

    Edges come from two places: ``CALLS`` edges of the graph payload whose
    endpoints are both known functions, and the ``calls`` / ``called_by``
    name lists carried by each function's metadata.
    """

    def __init__(
        self,
        graph_data: Optional[dict],
        functions: Iterable[CodeFunctionEvidence],
    ) -> None:
        """Index ``functions`` by id and by name, then load both edge sources.

        Args:
            graph_data: Parsed graph JSON; None or a non-dict value means
                "no graph edges".
            functions: Functions that may appear as edge endpoints.
        """
        self.functions_by_id: Dict[str, CodeFunctionEvidence] = {
            item.node_id: item for item in functions
        }
        # A name may belong to several functions, hence the list values.
        self.name_to_ids: Dict[str, List[str]] = defaultdict(list)
        for item in self.functions_by_id.values():
            for name in {item.name, item.qualified_name, item.node_id.removeprefix("Function:")}:
                if name:
                    self.name_to_ids[name].append(item.node_id)
        self.calls_by_id: Dict[str, Set[str]] = defaultdict(set)
        self.called_by_id: Dict[str, Set[str]] = defaultdict(set)
        self._load_graph_edges(graph_data or {})
        self._load_metadata_edges()

    def _load_graph_edges(self, graph_data: dict) -> None:
        """Add every ``CALLS`` edge whose source and target are known functions."""
        edges = graph_data.get("edges") if isinstance(graph_data, dict) else None
        for edge in edges or []:
            # Malformed entries are skipped instead of aborting the whole load.
            if not isinstance(edge, dict) or edge.get("type") != "CALLS":
                continue
            source_id = edge.get("source_id")
            target_id = edge.get("target_id")
            if source_id in self.functions_by_id and target_id in self.functions_by_id:
                self.calls_by_id[source_id].add(target_id)
                self.called_by_id[target_id].add(source_id)

    def _resolve_name(self, value: str) -> List[str]:
        """Return the node ids ``value`` refers to.

        ``value`` may be a node id itself or any indexed name (function name,
        qualified name, or node id without its ``Function:`` prefix).
        """
        if not value:
            return []
        if value in self.functions_by_id:
            return [value]
        # ``.get`` avoids inserting empty entries into the defaultdict.
        return self.name_to_ids.get(value, [])

    def _load_metadata_edges(self) -> None:
        """Add edges declared by each function's ``calls`` / ``called_by`` lists.

        Self-references are ignored.
        """
        for function in self.functions_by_id.values():
            for callee in function.calls:
                for target_id in self._resolve_name(callee):
                    if target_id != function.node_id:
                        self.calls_by_id[function.node_id].add(target_id)
                        self.called_by_id[target_id].add(function.node_id)
            for caller in function.called_by:
                for source_id in self._resolve_name(caller):
                    if source_id != function.node_id:
                        self.called_by_id[function.node_id].add(source_id)
                        self.calls_by_id[source_id].add(function.node_id)

    def _bfs(
        self,
        start_id: str,
        max_hops: int,
        next_nodes: Callable[[str], Iterable[str]],
    ) -> List[CodeFunctionEvidence]:
        """Collect functions reachable from ``start_id`` within ``max_hops`` steps.

        Neighbours are visited in sorted id order, so the result order is
        deterministic.  The start node itself is not included.
        """
        seen = {start_id}
        result: List[CodeFunctionEvidence] = []
        queue = deque([(start_id, 0)])
        while queue:
            current_id, depth = queue.popleft()
            if depth >= max_hops:
                continue
            for next_id in sorted(next_nodes(current_id)):
                if next_id in seen:
                    continue
                seen.add(next_id)
                function = self.functions_by_id.get(next_id)
                if function:
                    result.append(function)
                queue.append((next_id, depth + 1))
        return result

    def expand(self, node_id: str, max_hops: int = 2) -> CodeGraphContext:
        """Return callers, callees and textual call chains around ``node_id``.

        Chains are built from at most five callers and five callees as pairs,
        plus caller -> center -> callee triples for the first three of each;
        duplicates are removed while keeping first-seen order.
        """
        callers = self._bfs(node_id, max_hops, lambda item: self.called_by_id.get(item, set()))
        callees = self._bfs(node_id, max_hops, lambda item: self.calls_by_id.get(item, set()))
        center = self.functions_by_id.get(node_id)
        center_name = center.name if center else node_id
        call_chains: List[str] = []
        for caller in callers[:5]:
            call_chains.append(f"{caller.name} -> {center_name}")
        for callee in callees[:5]:
            call_chains.append(f"{center_name} -> {callee.name}")
        for caller in callers[:3]:
            for callee in callees[:3]:
                call_chains.append(f"{caller.name} -> {center_name} -> {callee.name}")
        return CodeGraphContext(
            node_id=node_id,
            callers=callers,
            callees=callees,
            call_chains=list(dict.fromkeys(call_chains)),
        )

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import List
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.schema import CodeSearchHit
class CodeFunctionRetriever:
    """Thin retrieval front-end that delegates to a code knowledge base adapter."""

    def __init__(self, adapter: CodeKnowledgeBaseAdapter) -> None:
        """Keep a reference to the adapter that performs the search."""
        self.adapter = adapter

    def retrieve(
        self,
        query: str,
        top_k: int = 8,
        min_similarity: float = 0.0,
    ) -> List[CodeSearchHit]:
        """Forward the arguments to ``adapter.search_functions`` and return its result."""
        search = self.adapter.search_functions
        return search(query=query, top_k=top_k, min_similarity=min_similarity)

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class CodeFunctionEvidence:
    """Normalized record of one function in a code knowledge base."""

    # Identity and location (required).
    node_id: str
    name: str
    qualified_name: str
    file: str
    start_line: Optional[int] = None
    end_line: Optional[int] = None

    # Descriptive text.
    signature: str = ""
    summary: str = ""
    logic_flow: str = ""
    code_snippet: str = ""

    # Relations to other code, kept as plain name lists.
    calls: List[str] = field(default_factory=list)
    called_by: List[str] = field(default_factory=list)
    includes: List[str] = field(default_factory=list)

    # Embedding bookkeeping.
    embedding_model: str = ""
    embedding_dim: int = 0
    embedding_available: bool = True

    # The metadata row this record was built from.
    raw: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class CodeSearchHit:
    """One search result: a function plus its similarity, distance and rank."""

    evidence: CodeFunctionEvidence
    similarity: float
    distance: float
    rank: int

    def to_dict(self) -> Dict[str, Any]:
        """Return the evidence dict extended with the three scoring fields.

        Scoring fields overwrite evidence keys of the same name.
        """
        return {
            **self.evidence.to_dict(),
            "similarity": self.similarity,
            "distance": self.distance,
            "rank": self.rank,
        }
@dataclass
class CodeGraphContext:
    """Call-graph neighbourhood collected around one function node."""

    node_id: str
    callers: List[CodeFunctionEvidence] = field(default_factory=list)
    callees: List[CodeFunctionEvidence] = field(default_factory=list)
    call_chains: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary with callers and callees converted via ``to_dict``."""
        caller_dicts = [caller.to_dict() for caller in self.callers]
        callee_dicts = [callee.to_dict() for callee in self.callees]
        return {
            "node_id": self.node_id,
            "callers": caller_dicts,
            "callees": callee_dicts,
            "call_chains": self.call_chains,
        }

View File

@@ -0,0 +1,4 @@
from app.services.consistency.comparator import ConsistencyComparator
# Public API of this package: only the comparator is re-exported.
__all__ = ["ConsistencyComparator"]

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import json
import logging
import re
from typing import Any, Dict, Iterable, List, Optional
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
from app.services.code_kb.formatter import format_evidence_context
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
from app.services.consistency.prompt import build_judgment_prompt, build_requirement_query
from app.services.consistency.schema import ConsistencyResultItem, RequirementSnapshot, VERDICTS
from app.services.consistency.scorer import coverage_score
# Module-level logger; used to record LLM judgment failures.
logger = logging.getLogger(__name__)
def _clip(value: str, limit: int) -> str:
    """Return ``value`` cut to ``limit`` characters with a truncation marker.

    A falsy ``value`` is treated as an empty string. When the text is longer
    than ``limit`` it is sliced, right-stripped, and suffixed with
    ``"\\n...[truncated]"``; otherwise it is returned unchanged.
    """
    text = value or ""
    if len(text) > limit:
        return text[:limit].rstrip() + "\n...[truncated]"
    return text
def _as_list(value: Any) -> List[str]:
    """Coerce an arbitrary value into a list of non-blank strings.

    * ``None`` gives an empty list.
    * Lists and tuples are stringified item by item; items whose string form
      is blank are dropped.
    * Strings are first tried as JSON (the decoded value is coerced again);
      if that fails they are split into stripped, non-empty lines.
    * Anything else becomes a single-element list holding ``str(value)``.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value if str(item).strip()]
    if not isinstance(value, str):
        return [str(value)]

    stripped = value.strip()
    if not stripped:
        return []
    try:
        decoded = json.loads(stripped)
    except json.JSONDecodeError:
        # Not JSON: treat the text as one entry per non-empty line.
        return [line.strip() for line in stripped.splitlines() if line.strip()]
    return _as_list(decoded)
def requirement_to_snapshot(requirement: Any) -> RequirementSnapshot:
    """Build a ``RequirementSnapshot`` from a dict or an attribute-style object.

    Each field is read under its snake_case name first and, if that value is
    falsy, under its camelCase alias (``requirement_uid`` falls back to
    ``id``). Missing identifiers and text fields default to ``""``.
    """
    if isinstance(requirement, dict):
        read = requirement.get
    else:
        def read(key, default=None):
            return getattr(requirement, key, default)

    def pick(primary: str, alias: str) -> Any:
        # Mirrors ``read(primary) or read(alias)``, including short-circuiting.
        first = read(primary)
        return first if first else read(alias)

    return RequirementSnapshot(
        requirement_uid=pick("requirement_uid", "id") or "",
        title=read("title") or "",
        description=read("description") or "",
        acceptance_criteria=_as_list(pick("acceptance_criteria", "acceptanceCriteria")),
        requirement_type=pick("requirement_type", "requirementType"),
        section_title=pick("section_title", "sectionTitle"),
        interface_name=pick("interface_name", "interfaceName"),
        interface_type=pick("interface_type", "interfaceType"),
        data_source=pick("data_source", "dataSource"),
        data_destination=pick("data_destination", "dataDestination"),
    )
class ConsistencyComparator:
    """Judge whether requirements are covered by functions in a code knowledge base.

    For every requirement the comparator searches the knowledge base for
    candidate functions, expands their call context, obtains a judgment
    (LLM-based, heuristic, or "missing" when nothing was retrieved) and packs
    everything into a ``ConsistencyResultItem``.
    """

    def __init__(
        self,
        code_kb_adapter: CodeKnowledgeBaseAdapter,
        llm: Any = None,
        use_llm: bool = True,
    ) -> None:
        """Store the collaborators.

        Args:
            code_kb_adapter: Adapter used for function search and call-context
                expansion.
            llm: Optional model object. When ``None`` and ``use_llm`` is true,
                one is created through ``LLMFactory`` on demand.
            use_llm: When false, a heuristic judgment replaces the LLM call.
        """
        self.code_kb_adapter = code_kb_adapter
        self.llm = llm
        self.use_llm = use_llm

    def compare_requirements(
        self,
        requirements: Iterable[Any],
        top_k: int = 8,
        max_call_hops: int = 2,
        min_similarity: float = 0.55,
    ) -> List[ConsistencyResultItem]:
        """Run ``compare_requirement`` for each requirement, preserving order."""
        return [
            self.compare_requirement(
                requirement,
                top_k=top_k,
                max_call_hops=max_call_hops,
                min_similarity=min_similarity,
            )
            for requirement in requirements
        ]

    def compare_requirement(
        self,
        requirement: Any,
        top_k: int = 8,
        max_call_hops: int = 2,
        min_similarity: float = 0.55,
    ) -> ConsistencyResultItem:
        """Compare one requirement against the code knowledge base.

        Args:
            requirement: Dict or attribute-style object describing the requirement.
            top_k: Passed to the adapter's function search.
            max_call_hops: Passed to the adapter's call-context expansion.
            min_similarity: Passed to the adapter's function search.

        Returns:
            The assembled result item, including the normalized raw judgment.
        """
        snapshot = requirement_to_snapshot(requirement)
        query = build_requirement_query(snapshot)
        hits = self.code_kb_adapter.search_functions(
            query=query,
            top_k=top_k,
            min_similarity=min_similarity,
        )
        contexts = [
            self.code_kb_adapter.expand_call_context(hit.evidence.node_id, max_hops=max_call_hops)
            for hit in hits
        ]

        if not hits:
            judgment = self._missing_judgment("未找到满足相似度阈值的函数证据。")
        elif not self.use_llm:
            judgment = self._heuristic_judgment(hits, contexts)
        else:
            judgment = self._llm_judgment(snapshot, hits, contexts)

        judgment = self._normalize_judgment(judgment)
        judgment["requirement_snapshot"] = snapshot.to_dict()
        score = coverage_score(snapshot, hits, contexts, judgment)
        matched_functions = [self._matched_function_payload(hit) for hit in hits]
        call_chains = self._collect_call_chains(contexts)

        return ConsistencyResultItem(
            requirement_uid=snapshot.requirement_uid,
            requirement_title=snapshot.title,
            requirement_type=snapshot.requirement_type,
            requirement_text=snapshot.description,
            verdict=judgment["verdict"],
            coverage_score=score,
            confidence=float(judgment.get("confidence") or 0.0),
            matched_functions=matched_functions,
            covered_points=_as_list(judgment.get("covered_points")),
            missing_points=_as_list(judgment.get("missing_points")),
            conflict_points=_as_list(judgment.get("conflict_points")),
            call_chain_evidence=call_chains,
            suggestion=str(judgment.get("suggestion") or ""),
            raw_judgment=judgment,
        )

    def _llm_judgment(
        self,
        requirement: RequirementSnapshot,
        hits: List[CodeSearchHit],
        contexts: List[CodeGraphContext],
    ) -> Dict[str, Any]:
        """Ask the LLM for a judgment; fall back to ``uncertain`` on any failure.

        Any exception (prompt building, model creation, invocation, or response
        parsing) is logged and converted into a low-confidence fallback
        judgment flagged with ``"fallback": True``.
        """
        try:
            evidence_context = format_evidence_context(hits, contexts)
            prompt = build_judgment_prompt(requirement, evidence_context)
            # Imported lazily so the factory is only needed when an LLM call is made.
            from app.services.llm.llm_factory import LLMFactory

            llm = self.llm or LLMFactory.create(temperature=0, streaming=False)
            response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
            text = getattr(response, "content", response)
            return self.parse_json_judgment(str(text))
        except Exception as exc:
            logger.exception("LLM consistency judgment failed: %s", exc)
            return {
                "verdict": "uncertain",
                "confidence": 0.2,
                "covered_points": [],
                "missing_points": ["模型判定失败,无法可靠确认覆盖情况。"],
                "conflict_points": [],
                "primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
                "reasoning": f"LLM judgment failed: {exc}",
                "suggestion": "请检查模型配置,或人工复核匹配函数证据。",
                "fallback": True,
            }

    @staticmethod
    def parse_json_judgment(raw_text: str) -> Dict[str, Any]:
        """Parse a model response into a judgment dictionary.

        A surrounding Markdown code fence is stripped first. If the remaining
        text is not valid JSON, the span from the first ``{`` to the last ``}``
        is tried instead.

        Raises:
            json.JSONDecodeError: If no JSON could be decoded.
            ValueError: If the decoded JSON is not an object. Callers rely on
                ``.get`` being available, so arrays and scalars are rejected
                here rather than crashing later.
        """
        text = raw_text.strip()
        if text.startswith("```"):
            text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
            text = re.sub(r"```$", "", text).strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            match = re.search(r"\{.*\}", text, flags=re.DOTALL)
            if not match:
                raise
            parsed = json.loads(match.group(0))
        if not isinstance(parsed, dict):
            raise ValueError(
                f"Expected a JSON object for the judgment, got {type(parsed).__name__}."
            )
        return parsed

    def _heuristic_judgment(
        self,
        hits: List[CodeSearchHit],
        contexts: List[CodeGraphContext],
    ) -> Dict[str, Any]:
        """Produce a conservative judgment from the first hit's similarity only.

        A similarity of at least 0.78 yields ``partial`` (confidence capped at
        0.68); anything lower yields ``uncertain`` (confidence capped at 0.5).
        """
        best = hits[0].similarity if hits else 0.0
        if best >= 0.78:
            verdict = "partial"
            confidence = min(0.68, best)
        else:
            verdict = "uncertain"
            confidence = min(0.5, best)
        return {
            "verdict": verdict,
            "confidence": confidence,
            "covered_points": [],
            "missing_points": ["未启用 LLM 判定,无法细分验收准则覆盖点。"],
            "conflict_points": [],
            "primary_evidence": [hit.evidence.node_id for hit in hits[:3]],
            "reasoning": "仅基于向量召回和调用图生成保守判定。",
            "suggestion": "启用模型判定或人工复核主要匹配函数。",
            "call_context_count": len(contexts),
        }

    def _missing_judgment(self, reason: str) -> Dict[str, Any]:
        """Return a ``missing`` judgment whose reasoning is ``reason``."""
        return {
            "verdict": "missing",
            "confidence": 0.75,
            "covered_points": [],
            "missing_points": [reason],
            "conflict_points": [],
            "primary_evidence": [],
            "reasoning": reason,
            "suggestion": "补充代码实现或降低阈值后重新召回,并人工确认是否存在命名差异。",
        }

    def _normalize_judgment(self, judgment: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``judgment`` with a valid verdict and confidence.

        Unknown verdicts become ``uncertain``. The confidence is converted to
        a float in ``[0.0, 1.0]``; unparsable or NaN values become ``0.0``.
        Missing list/text fields are filled with empty defaults.
        """
        verdict = str(judgment.get("verdict") or "uncertain").strip().lower()
        if verdict not in VERDICTS:
            verdict = "uncertain"

        try:
            confidence = float(judgment.get("confidence", 0.0))
        except (TypeError, ValueError):
            confidence = 0.0
        else:
            # NaN compares unequal to itself; without this guard min()/max()
            # would silently turn it into full confidence (1.0).
            if confidence != confidence:
                confidence = 0.0
            confidence = max(0.0, min(1.0, confidence))

        normalized = dict(judgment)
        normalized["verdict"] = verdict
        normalized["confidence"] = confidence
        normalized.setdefault("covered_points", [])
        normalized.setdefault("missing_points", [])
        normalized.setdefault("conflict_points", [])
        normalized.setdefault("primary_evidence", [])
        normalized.setdefault("reasoning", "")
        normalized.setdefault("suggestion", "")
        return normalized

    def _matched_function_payload(self, hit: CodeSearchHit) -> Dict[str, Any]:
        """Return the trimmed, serializable description of one matched function."""
        item = hit.evidence
        return {
            "node_id": item.node_id,
            "name": item.name,
            "file": item.file,
            "start_line": item.start_line,
            "end_line": item.end_line,
            "similarity": round(hit.similarity, 4),
            "role": item.summary[:120] if item.summary else "",
            "evidence_summary": item.summary,
            "logic_flow": _clip(item.logic_flow, 1200),
            "code_snippet": _clip(item.code_snippet, 2000),
            "calls": item.calls[:20],
            "called_by": item.called_by[:20],
            "signature": item.signature,
        }

    def _collect_call_chains(self, contexts: List[CodeGraphContext]) -> List[str]:
        """Merge call chains of all contexts, de-duplicated in order, max 30."""
        chains: List[str] = []
        for context in contexts:
            chains.extend(context.call_chains)
        return list(dict.fromkeys(chains))[:30]

View File

@@ -0,0 +1,134 @@
from __future__ import annotations
import io
import json
from typing import Any, Dict, Iterable, List
def normalize_result_dicts(results: Iterable[Any]) -> List[Dict[str, Any]]:
    """Convert result objects of mixed kinds into plain dictionaries.

    * Objects exposing ``to_dict`` are converted through it.
    * Dictionaries are passed through unchanged (same object).
    * Any other object is read attribute by attribute, with defaults for
      missing attributes. The requirement title, type and text are included
      so the exporters' title/type columns are not lost for such objects.

    Args:
        results: Iterable of result items.

    Returns:
        One dictionary per input item, in input order.
    """
    normalized: List[Dict[str, Any]] = []
    for item in results:
        if hasattr(item, "to_dict"):
            normalized.append(item.to_dict())
        elif isinstance(item, dict):
            normalized.append(item)
        else:
            normalized.append(
                {
                    "requirement_uid": getattr(item, "requirement_uid", ""),
                    "requirement_title": getattr(item, "requirement_title", ""),
                    "requirement_type": getattr(item, "requirement_type", None),
                    "requirement_text": getattr(item, "requirement_text", ""),
                    "verdict": getattr(item, "verdict", ""),
                    "coverage_score": getattr(item, "coverage_score", 0.0),
                    "confidence": getattr(item, "confidence", 0.0),
                    "matched_functions": getattr(item, "matched_functions", []),
                    "covered_points": getattr(item, "covered_points", []),
                    "missing_points": getattr(item, "missing_points", []),
                    "conflict_points": getattr(item, "conflict_points", []),
                    "call_chain_evidence": getattr(item, "call_chain_evidence", []),
                    "suggestion": getattr(item, "suggestion", ""),
                    "raw_judgment": getattr(item, "raw_judgment", {}),
                }
            )
    return normalized
def export_json(results: Iterable[Any]) -> bytes:
    """Serialize the results as UTF-8 encoded, indented JSON.

    The document has a single top-level key ``"results"`` holding the
    normalized result dictionaries; non-ASCII characters are kept as-is.
    """
    document = {"results": normalize_result_dicts(results)}
    rendered = json.dumps(document, ensure_ascii=False, indent=2)
    return rendered.encode("utf-8")
def _markdown_table_cell(value: Any) -> str:
    """Render ``value`` so it cannot break out of a Markdown table cell.

    Pipes are replaced with ``/`` and line breaks with a single space, since
    either one would end the cell or the row.
    """
    text = str(value).replace("|", "/")
    return text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")


def export_markdown(results: Iterable[Any]) -> str:
    """Render the results as a Markdown report.

    The report starts with a summary table (one row per requirement) followed
    by one section per requirement listing the verdict, scores, suggestion,
    matched functions, missing points and, when present, conflict points.

    Args:
        results: Iterable of result items accepted by ``normalize_result_dicts``.

    Returns:
        The report text with ``"\\n"`` line separators.
    """
    rows = normalize_result_dicts(results)
    lines = [
        "# 需求代码一致性比对报告",
        "",
        "| 需求 ID | 判定 | 覆盖分 | 置信度 | 匹配函数 | 缺失点 | 建议 |",
        "| --- | --- | ---: | ---: | ---: | ---: | --- |",
    ]

    # Summary table: free-text cells are sanitized so that a pipe or a line
    # break inside an id, verdict or suggestion cannot corrupt the row.
    for item in rows:
        lines.append(
            "| {uid} | {verdict} | {score:.2f} | {confidence:.2f} | {functions} | {missing} | {suggestion} |".format(
                uid=_markdown_table_cell(item.get("requirement_uid", "")),
                verdict=_markdown_table_cell(item.get("verdict", "")),
                score=float(item.get("coverage_score") or 0),
                confidence=float(item.get("confidence") or 0),
                functions=len(item.get("matched_functions") or []),
                missing=len(item.get("missing_points") or []),
                suggestion=_markdown_table_cell(item.get("suggestion") or ""),
            )
        )

    # Detail sections, one per requirement.
    for item in rows:
        lines.extend(
            [
                "",
                f"## {item.get('requirement_uid', '')} {item.get('requirement_title', '')}",
                "",
                f"- 判定: `{item.get('verdict', '')}`",
                f"- 覆盖分: {float(item.get('coverage_score') or 0):.2f}",
                f"- 置信度: {float(item.get('confidence') or 0):.2f}",
                f"- 建议: {item.get('suggestion') or '-'}",
                "",
                "### 匹配函数",
            ]
        )
        for function in item.get("matched_functions") or []:
            lines.append(
                f"- `{function.get('name')}` {function.get('file')}:{function.get('start_line')} "
                f"(similarity={float(function.get('similarity') or 0):.2f})"
            )
        lines.extend(["", "### 缺失点"])
        # An empty list is shown as a single "-" bullet.
        for point in item.get("missing_points") or ["-"]:
            lines.append(f"- {point}")
        if item.get("conflict_points"):
            lines.extend(["", "### 冲突点"])
            for point in item.get("conflict_points") or []:
                lines.append(f"- {point}")
    return "\n".join(lines)
def export_excel(results: Iterable[Any]) -> bytes:
    """Render the results as an ``.xlsx`` workbook and return its bytes.

    The workbook has one sheet named ``Consistency`` with a header row and one
    row per requirement.

    Raises:
        RuntimeError: If ``openpyxl`` cannot be imported.
    """
    try:
        from openpyxl import Workbook
    except ImportError as exc:
        raise RuntimeError("openpyxl is required to export Excel reports.") from exc

    records = normalize_result_dicts(results)
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = "Consistency"

    header_row = [
        "需求ID",
        "需求标题",
        "需求类型",
        "判定",
        "覆盖分",
        "置信度",
        "匹配函数数量",
        "主要文件",
        "缺失点数量",
        "建议",
    ]
    sheet.append(header_row)

    for record in records:
        matched = record.get("matched_functions") or []
        # "Main file" column: file of the first matched function, if any.
        primary_file = matched[0].get("file", "") if matched else ""
        missing_count = len(record.get("missing_points") or [])
        sheet.append(
            [
                record.get("requirement_uid", ""),
                record.get("requirement_title", ""),
                record.get("requirement_type", ""),
                record.get("verdict", ""),
                record.get("coverage_score", 0),
                record.get("confidence", 0),
                len(matched),
                primary_file,
                missing_count,
                record.get("suggestion", ""),
            ]
        )

    buffer = io.BytesIO()
    workbook.save(buffer)
    return buffer.getvalue()

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import json
from app.services.consistency.schema import RequirementSnapshot
# Instruction text placed in front of the JSON payload by build_judgment_prompt.
# It tells the model to judge only from the supplied evidence, to answer
# "uncertain" when evidence is insufficient, and to reply with JSON.
SYSTEM_INSTRUCTION = """你是需求代码一致性审查助手。
只能基于输入的需求、验收准则、函数摘要、代码片段、调用链证据判断。
不得补充未给出的代码事实。
证据不足时输出 uncertain。
输出严格 JSON不要 Markdown。"""
def build_requirement_query(requirement: RequirementSnapshot) -> str:
    """Build the retrieval query text for a requirement.

    Interface requirements (``requirement_type`` equal to ``"interface"``,
    case-insensitively) lead with the interface fields followed by the
    description. All other requirements lead with the description and
    acceptance criteria. Empty parts are skipped and the rest is joined with
    newlines.
    """
    kind = (requirement.requirement_type or "").lower()
    if kind == "interface":
        candidates = [
            requirement.interface_name or "",
            requirement.interface_type or "",
            requirement.data_source or "",
            requirement.data_destination or "",
            requirement.description,
        ]
    else:
        candidates = [
            requirement.description,
            "\n".join(requirement.acceptance_criteria),
            requirement.section_title or "",
            requirement.interface_name or "",
            requirement.data_source or "",
            requirement.data_destination or "",
        ]
    non_empty = [piece for piece in candidates if piece]
    return "\n".join(non_empty).strip()
def build_judgment_prompt(requirement: RequirementSnapshot, evidence_context: str) -> str:
    """Compose the judgment prompt: system instruction plus a JSON payload.

    The payload carries the requirement dictionary, the evidence text and a
    template of the expected output fields, serialized as indented JSON with
    non-ASCII characters preserved.
    """
    expected_output = {
        "verdict": "implemented | partial | missing | conflict | uncertain",
        "confidence": 0.0,
        "covered_points": [],
        "missing_points": [],
        "conflict_points": [],
        "primary_evidence": [],
        "reasoning": "brief reason based only on evidence",
        "suggestion": "next action",
    }
    payload = {
        "requirement": requirement.to_dict(),
        "evidence": evidence_context,
        "output_schema": expected_output,
    }
    body = json.dumps(payload, ensure_ascii=False, indent=2)
    return SYSTEM_INSTRUCTION + "\n\n" + body

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import argparse
from pathlib import Path
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the consistency comparison script.

    Required: ``--srs-extraction-id`` (int), ``--vector-path``,
    ``--metadata-path``, ``--graph-path`` and ``--output``. Optional report
    paths, retrieval parameters and ``--no-llm`` have defaults.
    """
    parser = argparse.ArgumentParser(description="Run requirement-code consistency comparison.")
    parser.add_argument("--srs-extraction-id", type=int, required=True)
    for required_path_flag in ("--vector-path", "--metadata-path", "--graph-path", "--output"):
        parser.add_argument(required_path_flag, required=True)
    for optional_report_flag in ("--output-excel", "--output-markdown"):
        parser.add_argument(optional_report_flag, default=None)
    parser.add_argument("--top-k", type=int, default=8)
    parser.add_argument("--max-call-hops", type=int, default=2)
    parser.add_argument("--min-similarity", type=float, default=0.55)
    parser.add_argument("--no-llm", action="store_true")
    return parser.parse_args()
def main() -> int:
    """Run the comparison for one SRS extraction and write the reports.

    Loads the code knowledge base from the given paths, reads the requirements
    of the extraction ordered by ``sort_order``, compares them, and writes the
    JSON report plus the optional Markdown and Excel reports.

    Returns:
        ``0`` on completion.
    """
    options = parse_args()

    # Application imports are deferred until after argument parsing.
    from app.db.session import SessionLocal
    from app.models.tooling import SRSRequirement
    from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter
    from app.services.consistency.comparator import ConsistencyComparator
    from app.services.consistency.exporter import export_excel, export_json, export_markdown

    adapter = CodeKnowledgeBaseAdapter()
    adapter.load(options.vector_path, options.metadata_path, options.graph_path)
    comparator = ConsistencyComparator(adapter, use_llm=not options.no_llm)

    session = SessionLocal()
    try:
        scoped = session.query(SRSRequirement).filter(
            SRSRequirement.extraction_id == options.srs_extraction_id
        )
        requirements = scoped.order_by(SRSRequirement.sort_order).all()
        results = comparator.compare_requirements(
            requirements,
            top_k=options.top_k,
            max_call_hops=options.max_call_hops,
            min_similarity=options.min_similarity,
        )
    finally:
        session.close()

    Path(options.output).write_bytes(export_json(results))
    if options.output_markdown:
        Path(options.output_markdown).write_text(export_markdown(results), encoding="utf-8")
    if options.output_excel:
        Path(options.output_excel).write_bytes(export_excel(results))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
# Closed set of verdict labels a consistency judgment may carry.
VERDICTS = {"implemented", "partial", "missing", "conflict", "uncertain"}
@dataclass
class RequirementSnapshot:
    """Plain copy of the requirement fields used during a comparison."""

    # Core identification and text.
    requirement_uid: str
    title: str
    description: str
    acceptance_criteria: List[str] = field(default_factory=list)

    # Optional classification and interface details.
    requirement_type: Optional[str] = None
    section_title: Optional[str] = None
    interface_name: Optional[str] = None
    interface_type: Optional[str] = None
    data_source: Optional[str] = None
    data_destination: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class ConsistencyResultItem:
    """Outcome of comparing one requirement against the code knowledge base."""

    # Requirement being judged.
    requirement_uid: str
    requirement_title: str
    requirement_type: Optional[str]
    requirement_text: str

    # Judgment and scores.
    verdict: str
    coverage_score: float
    confidence: float

    # Supporting evidence and details.
    matched_functions: List[Dict[str, Any]]
    covered_points: List[str] = field(default_factory=list)
    missing_points: List[str] = field(default_factory=list)
    conflict_points: List[str] = field(default_factory=list)
    call_chain_evidence: List[str] = field(default_factory=list)
    suggestion: str = ""
    raw_judgment: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary (deep-copied by ``asdict``)."""
        as_mapping = asdict(self)
        return as_mapping

View File

@@ -0,0 +1,120 @@
from __future__ import annotations
import re
from typing import Any, Dict, Iterable, List
from app.services.code_kb.schema import CodeGraphContext, CodeSearchHit
from app.services.consistency.schema import RequirementSnapshot
def _clamp(value: float) -> float:
    """Limit ``value`` to the closed interval ``[0.0, 1.0]``."""
    capped = min(1.0, value)
    return max(0.0, capped)
def _tokens(*values: str) -> List[str]:
    """Split the given texts into lower-cased tokens of at least two characters.

    The texts are joined with spaces (falsy values count as empty), lowered,
    and split on every run of characters that is not ``a-z``, ``0-9``, ``_``
    or in the range U+4E00..U+9FFF.
    """
    lowered = " ".join(value or "" for value in values).lower()
    pieces = re.split(r"[^a-z0-9_\u4e00-\u9fff]+", lowered)
    return [piece for piece in pieces if len(piece) >= 2]
def semantic_score(hits: List[CodeSearchHit]) -> float:
    """Score retrieval quality from hit similarities, clamped to ``[0, 1]``.

    Combines the highest similarity among all hits (weight 0.7) with the mean
    similarity of the first three hits (weight 0.3). No hits gives ``0.0``.
    """
    if not hits:
        return 0.0
    best = max(hit.similarity for hit in hits)
    leading = hits[:3]
    leading_mean = sum(hit.similarity for hit in leading) / min(3, len(hits))
    blended = best * 0.7 + leading_mean * 0.3
    return max(0.0, min(1.0, blended))
def acceptance_coverage_score(requirement: RequirementSnapshot, judgment: Dict[str, Any]) -> float:
    """Estimate how much of the acceptance criteria the judgment covers.

    With acceptance criteria present:
      * if missing points are reported, the share of criteria not reported
        missing is returned;
      * otherwise, if covered points are reported, their count over the
        criteria count is returned (clamped to ``[0, 1]``);
      * otherwise the verdict decides: implemented 1.0, partial 0.4, else 0.0.
    Without criteria a fixed value per verdict is used (0.35 for unknown ones).
    """
    criteria = requirement.acceptance_criteria or []
    covered = judgment.get("covered_points") or []
    missing = judgment.get("missing_points") or []
    verdict = judgment.get("verdict")

    if not criteria:
        per_verdict = {"implemented": 1.0, "partial": 0.55, "conflict": 0.25, "missing": 0.0}
        return per_verdict.get(verdict, 0.35)

    total = len(criteria)
    if missing:
        not_missing = total - min(len(missing), total)
        return max(0.0, min(1.0, not_missing / total))
    if covered:
        return max(0.0, min(1.0, len(covered) / total))
    if verdict == "implemented":
        return 1.0
    if verdict == "partial":
        return 0.4
    return 0.0
def evidence_strength_score(hits: List[CodeSearchHit]) -> float:
    """Score how complete the evidence of the first five hits is.

    Each hit earns the fraction of six properties that are present (file,
    start line, end line, summary, logic flow, code snippet); the result is
    the mean over the inspected hits, clamped to ``[0, 1]``.
    """
    if not hits:
        return 0.0
    completeness: List[float] = []
    for hit in hits[:5]:
        evidence = hit.evidence
        present = (
            bool(evidence.file),
            evidence.start_line is not None,
            evidence.end_line is not None,
            bool(evidence.summary),
            bool(evidence.logic_flow),
            bool(evidence.code_snippet),
        )
        completeness.append(sum(present) / len(present))
    mean_completeness = sum(completeness) / len(completeness)
    return max(0.0, min(1.0, mean_completeness))
def call_graph_score(contexts: Iterable[CodeGraphContext]) -> float:
    """Score how much call-graph information the first five contexts carry.

    A context earns 0.35 for having callers, 0.35 for having callees and 0.30
    for having call chains; the result is the mean over the inspected
    contexts, clamped to ``[0, 1]``.
    """
    materialized = list(contexts)
    if not materialized:
        return 0.0
    per_context = []
    for context in materialized[:5]:
        weighted_parts = (
            (0.35, context.callers),
            (0.35, context.callees),
            (0.30, context.call_chains),
        )
        # Start from 0.0 and add in declaration order to keep float results stable.
        per_context.append(sum((weight for weight, part in weighted_parts if part), 0.0))
    mean_score = sum(per_context) / len(per_context)
    return max(0.0, min(1.0, mean_score))
def exact_match_score(requirement: RequirementSnapshot, hits: List[CodeSearchHit]) -> float:
    """Return the share of key requirement tokens found in the top evidence text.

    Key tokens come from the interface name/type, data source/destination and
    title; when those produce none, the first 12 description tokens are used.
    The evidence text is the lower-cased name, qualified name, summary and
    logic flow of the first five hits. No hits or no tokens gives ``0.0``.
    """
    if not hits:
        return 0.0

    key_tokens = _tokens(
        requirement.interface_name or "",
        requirement.interface_type or "",
        requirement.data_source or "",
        requirement.data_destination or "",
        requirement.title or "",
    )
    if not key_tokens:
        key_tokens = _tokens(requirement.description)[:12]
    if not key_tokens:
        return 0.0

    fragments = []
    for hit in hits[:5]:
        evidence = hit.evidence
        fragments.append(
            f"{evidence.name} {evidence.qualified_name} {evidence.summary} {evidence.logic_flow}"
        )
    haystack = " ".join(fragments).lower()

    found = 0
    for token in key_tokens:
        if token.lower() in haystack:
            found += 1
    return _clamp(found / len(key_tokens))
def coverage_score(
    requirement: RequirementSnapshot,
    hits: List[CodeSearchHit],
    contexts: List[CodeGraphContext],
    judgment: Dict[str, Any],
) -> float:
    """Combine the partial scores into one coverage score in ``[0, 1]``.

    Weights: semantic 0.25, acceptance coverage 0.30, evidence strength 0.20,
    call graph 0.15, exact match 0.10. The total is capped by verdict
    (``missing`` 0.25, ``conflict`` 0.45, ``uncertain`` 0.55) and rounded to
    four decimals.
    """
    semantic_part = semantic_score(hits) * 0.25
    acceptance_part = acceptance_coverage_score(requirement, judgment) * 0.30
    evidence_part = evidence_strength_score(hits) * 0.20
    graph_part = call_graph_score(contexts) * 0.15
    exact_part = exact_match_score(requirement, hits) * 0.10
    total = semantic_part + acceptance_part + evidence_part + graph_part + exact_part

    verdict = judgment.get("verdict")
    ceiling = None
    if verdict == "missing":
        ceiling = 0.25
    elif verdict == "uncertain":
        ceiling = 0.55
    elif verdict == "conflict":
        ceiling = 0.45
    if ceiling is not None:
        total = min(total, ceiling)
    return round(_clamp(total), 4)

View File

@@ -0,0 +1,711 @@
from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
import zipfile
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import (
CodeKnowledgeBase,
ConsistencyJob,
ConsistencyResult,
SRSExtraction,
SRSRequirement,
ToolJob,
)
from app.schemas.consistency import CodeKnowledgeBaseCreate, ConsistencyJobCreate
from app.services.code_kb.adapter import CodeKnowledgeBaseAdapter, read_code_kb_summary
from app.services.code_kb.formatter import format_evidence_context
from app.services.consistency.comparator import ConsistencyComparator
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.srs_job_service import _build_internal_title, _parse_generated_at
from app.tools.srs_reqs_qwen import get_srs_tool
# Relative upload directories beneath "uploads/".
# NOTE(review): presumably used for uploaded code projects and auto-consistency
# runs respectively; their use sites are not visible in this part of the file.
CODE_UPLOAD_ROOT = Path("uploads") / "code_kbs"
AUTO_UPLOAD_ROOT = Path("uploads") / "consistency_auto"
def _workspace_root() -> Path:
    """Locate the workspace directory that holds both sibling projects.

    Walks up from this file and returns the first ancestor containing both
    ``rag-web-ui`` and ``RAG-TEST-TOOLS``. If none matches, falls back to the
    ancestor four levels above this file's directory, or to the topmost
    ancestor when the path is shallower than that (for example inside a
    container), instead of raising ``IndexError``.
    """
    current = Path(__file__).resolve()
    ancestors = current.parents
    for parent in ancestors:
        if (parent / "rag-web-ui").exists() and (parent / "RAG-TEST-TOOLS").exists():
            return parent
    if len(ancestors) == 0:
        return current
    return ancestors[min(4, len(ancestors) - 1)]
def _rag_test_tools_root() -> Path:
    """Return the ``RAG-TEST-TOOLS`` directory, trying several known locations.

    Candidates are checked in order: beside the workspace root, three levels
    above this file, and beside the current working directory's parent. The
    first existing one is returned resolved; if none exists the first
    candidate is returned as-is so the caller can report it.
    """
    folder = "RAG-TEST-TOOLS"
    candidates = [
        _workspace_root() / folder,
        Path(__file__).resolve().parents[3] / folder,
        Path.cwd().parent / folder,
    ]
    existing = (candidate.resolve() for candidate in candidates if candidate.exists())
    return next(existing, candidates[0])
def safe_upload_name(file_name: str | None, fallback: str = "upload.bin") -> str:
    """Reduce a client-supplied file name to a bare, safe base name.

    Directory components are discarded. Backslashes are treated as path
    separators too, so Windows-style client paths are reduced to their last
    component on every platform. Names that are empty or resolve to ``.`` or
    ``..`` are replaced by ``fallback``.

    Args:
        file_name: Name sent by the client; may be ``None`` or empty.
        fallback: Name to use when no usable base name remains.

    Returns:
        The base name, or ``fallback``.
    """
    candidate = (file_name or fallback).replace("\\", "/")
    safe_name = Path(candidate).name
    # Path("..").name is "..", which must never be joined onto an upload dir.
    if safe_name in ("", ".", ".."):
        return fallback
    return safe_name
def ensure_upload_dir(*parts: str) -> Path:
    """Create (if needed) and return ``uploads/<parts...>`` as a relative path."""
    upload_dir = Path("uploads").joinpath(*parts)
    upload_dir.mkdir(parents=True, exist_ok=True)
    return upload_dir
def save_uploaded_bytes(target_dir: Path, file_name: str, content: bytes) -> Path:
    """Write uploaded bytes into ``target_dir`` and unpack zip archives.

    The file is stored under its sanitized name. When that name ends in
    ``.zip`` (case-insensitive) the archive is extracted into a sibling
    directory named after the file stem.

    Returns:
        The extraction directory for zip uploads, otherwise the stored file path.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    stored_path = target_dir / safe_upload_name(file_name)
    stored_path.write_bytes(content)
    if stored_path.suffix.lower() != ".zip":
        return stored_path
    unpack_dir = target_dir / stored_path.stem
    extract_zip_safe(stored_path, unpack_dir)
    return unpack_dir
def extract_zip_safe(zip_path: Path, target_dir: Path) -> None:
    """Extract a zip archive after checking every entry stays inside ``target_dir``.

    All entries are validated before anything is extracted, so an archive
    with an escaping entry is rejected as a whole.

    Args:
        zip_path: Archive to extract.
        target_dir: Destination directory; created if missing.

    Raises:
        ValueError: If an entry would resolve outside ``target_dir``. The
            underlying path error is attached as the cause.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    target_root = target_dir.resolve()
    with zipfile.ZipFile(zip_path) as archive:
        for member in archive.infolist():
            member_path = (target_dir / member.filename).resolve()
            try:
                # relative_to raises ValueError when member_path is not below target_root.
                member_path.relative_to(target_root)
            except ValueError as exc:
                raise ValueError(f"Unsafe zip entry: {member.filename}") from exc
        archive.extractall(target_dir)
def _build_code_kb_artifacts(
    project_path: str,
    output_dir: str,
    base_name: str,
    use_semantic: bool,
    model_profile: Any = None,
) -> Dict[str, str]:
    """Run the external code-KB builder in a subprocess and return its output.

    The builder module ``rag_test_tools.build_code_kb`` is executed with the
    current interpreter from the ``RAG-TEST-TOOLS`` directory. When a model
    profile is given, its API key, API base and model names are passed to the
    child process through environment variables.

    Args:
        project_path: Source project to index.
        output_dir: Directory for the generated artifacts; created if missing.
        base_name: Base name passed to the builder.
        use_semantic: When false, ``--skip-semantic`` is passed.
        model_profile: Optional object with ``api_key``, ``api_base``,
            ``chat_model`` and ``embedding_model`` attributes.

    Returns:
        The JSON document the builder printed on stdout.

    Raises:
        FileNotFoundError: If the ``RAG-TEST-TOOLS`` directory does not exist.
        RuntimeError: If the builder exits non-zero or prints invalid JSON.
    """
    tools_root = _rag_test_tools_root()
    if not tools_root.exists():
        raise FileNotFoundError(f"RAG-TEST-TOOLS not found: {tools_root}")

    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    command = [sys.executable, "-m", "rag_test_tools.build_code_kb"]
    command += ["--project", str(Path(project_path).resolve())]
    command += ["--output", str(output_path)]
    command += ["--base-name", base_name]
    if not use_semantic:
        command.append("--skip-semantic")

    env = os.environ.copy()
    if model_profile is not None:
        api_key = getattr(model_profile, "api_key", "") or ""
        api_base = getattr(model_profile, "api_base", "") or ""
        if api_key:
            # The same key is exported under every variable name set here.
            for variable in ("DASHSCOPE_API_KEY", "DASH_SCOPE_API_KEY", "QWEN_API_KEY"):
                env[variable] = api_key
        if api_base:
            for variable in ("DASH_SCOPE_API_BASE", "QWEN_API_URL"):
                env[variable] = api_base
        if getattr(model_profile, "chat_model", None):
            env["QWEN_CHAT_MODEL"] = model_profile.chat_model
        if getattr(model_profile, "embedding_model", None):
            env["QWEN_EMBEDDING_MODEL"] = model_profile.embedding_model

    completed = subprocess.run(
        command,
        cwd=str(tools_root),
        env=env,
        capture_output=True,
        text=True,
        timeout=3600,
        check=False,
    )
    if completed.returncode != 0:
        raise RuntimeError(
            "Code knowledge base build failed: "
            f"{completed.stderr or completed.stdout or completed.returncode}"
        )
    try:
        artifacts = json.loads(completed.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Code KB build returned invalid JSON: {completed.stdout}") from exc
    return artifacts
def _ensure_paths_exist(paths: Iterable[str]) -> None:
    """Raise ``FileNotFoundError`` listing every path that is empty or absent."""
    missing = []
    for path in paths:
        if not path or not Path(path).exists():
            missing.append(path)
    if missing:
        raise FileNotFoundError(f"Code knowledge base file path does not exist: {missing}")
def create_code_kb(db: Session, user_id: int, payload: CodeKnowledgeBaseCreate) -> CodeKnowledgeBase:
    """Register an existing code knowledge base from artifact paths.

    The three artifact paths must exist and must load through the adapter.
    The stored record is marked ``active`` and carries a summary merged from
    the artifact files and the loaded adapter (adapter values win on clashes).

    Raises:
        FileNotFoundError: If any artifact path is empty or does not exist.
    """
    artifact_paths = [payload.vector_path, payload.metadata_path, payload.graph_path]
    _ensure_paths_exist(artifact_paths)

    adapter = CodeKnowledgeBaseAdapter()
    adapter.load(payload.vector_path, payload.metadata_path, payload.graph_path)

    summary = dict(read_code_kb_summary(payload.metadata_path, payload.graph_path))
    summary.update(adapter.summary())

    record = CodeKnowledgeBase(
        user_id=user_id,
        name=payload.name,
        project_path=payload.project_path,
        vector_path=payload.vector_path,
        metadata_path=payload.metadata_path,
        graph_path=payload.graph_path,
        status="active",
        metadata_summary=summary,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
def create_uploaded_code_kb(
    db: Session,
    user_id: int,
    name: str,
    project_path: str,
    output_dir: str,
) -> CodeKnowledgeBase:
    """Create a ``pending`` knowledge base record for an uploaded project.

    Artifact paths are planned under ``output_dir`` using a base name derived
    from the current UTC timestamp; the base name and output directory are
    kept in ``metadata_summary``.
    """
    timestamp = datetime.utcnow().strftime('%Y%m%d%H%M%S%f')
    base_name = f"code_kb_{timestamp}"
    output_path = Path(output_dir).resolve()

    def planned(suffix: str) -> str:
        return str(output_path / f"{base_name}{suffix}")

    record = CodeKnowledgeBase(
        user_id=user_id,
        name=name,
        project_path=project_path,
        vector_path=planned("_rag.faiss"),
        metadata_path=planned("_rag_metadata.json"),
        graph_path=planned("_code_knowledge_graph.json"),
        status="pending",
        metadata_summary={"base_name": base_name, "output_dir": str(output_path)},
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
def run_code_kb_build(code_kb_id: int, use_semantic: bool = True) -> None:
    """Build the artifacts of a code knowledge base and update its status.

    Opens its own database session. The record moves to ``processing``, the
    external builder is run, the resulting artifacts are loaded for
    validation, and the record becomes ``active`` with a fresh summary. On any
    error the record is marked ``failed`` and the (truncated) error message is
    stored in ``metadata_summary``; the exception is not re-raised.

    Args:
        code_kb_id: Identifier of the record to build. Unknown ids are ignored.
        use_semantic: When true, an active model configuration of the owner is
            required and used for the build and for embeddings.
    """
    db = SessionLocal()
    code_kb = None
    try:
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        if not code_kb:
            return

        model_profile = None
        if use_semantic:
            model_profile = ModelConfigService.require_active_config(db, code_kb.user_id)
            ModelConfigService.touch_last_used(db, model_profile)

        code_kb.status = "processing"
        db.add(code_kb)
        db.commit()

        summary = code_kb.metadata_summary or {}
        base_name = summary.get("base_name") or f"code_kb_{code_kb.id}"
        output_dir = summary.get("output_dir") or str(Path(code_kb.vector_path).parent)
        artifact_paths = _build_code_kb_artifacts(
            project_path=code_kb.project_path or "",
            output_dir=output_dir,
            base_name=base_name,
            use_semantic=use_semantic,
            model_profile=model_profile,
        )
        code_kb.graph_path = artifact_paths["graph_path"]
        code_kb.vector_path = artifact_paths["vector_path"]
        code_kb.metadata_path = artifact_paths["metadata_path"]

        embedding_function = (
            EmbeddingsFactory.create(model_profile=model_profile)
            if model_profile is not None
            else None
        )
        # Loading validates that the produced artifacts are usable.
        adapter = CodeKnowledgeBaseAdapter(embedding_function=embedding_function)
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

        code_kb.status = "active"
        code_kb.metadata_summary = {
            **read_code_kb_summary(code_kb.metadata_path, code_kb.graph_path),
            **adapter.summary(),
            "source": "upload",
        }
        db.add(code_kb)
        db.commit()
    except Exception as exc:
        if code_kb is not None:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this the "failed" status could never be
            # stored. Uncommitted changes from the aborted build are discarded.
            db.rollback()
            code_kb.status = "failed"
            code_kb.metadata_summary = {
                **(code_kb.metadata_summary or {}),
                "error_message": str(exc)[:2000],
            }
            db.add(code_kb)
            db.commit()
    finally:
        db.close()
def list_code_kbs(db: Session, user_id: int) -> List[CodeKnowledgeBase]:
    """Return the code knowledge bases owned by ``user_id``, newest first."""
    owned = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.user_id == user_id)
    newest_first = owned.order_by(CodeKnowledgeBase.created_at.desc())
    return newest_first.all()
def get_owned_code_kb(db: Session, user_id: int, code_kb_id: int) -> Optional[CodeKnowledgeBase]:
    """Return the knowledge base with this id if ``user_id`` owns it, else ``None``."""
    matches_id = CodeKnowledgeBase.id == code_kb_id
    matches_owner = CodeKnowledgeBase.user_id == user_id
    query = db.query(CodeKnowledgeBase).filter(matches_id, matches_owner)
    return query.first()
def get_owned_srs_extraction(db: Session, user_id: int, extraction_id: int) -> Optional[SRSExtraction]:
    """Return the SRS extraction with this id if its tool job belongs to ``user_id``.

    Ownership is checked through the ``ToolJob`` joined on the extraction's
    ``job_id``. Returns ``None`` when nothing matches.
    """
    with_job = db.query(SRSExtraction).join(ToolJob, SRSExtraction.job_id == ToolJob.id)
    scoped = with_job.filter(SRSExtraction.id == extraction_id, ToolJob.user_id == user_id)
    return scoped.first()
def create_consistency_job(
    db: Session,
    user_id: int,
    payload: ConsistencyJobCreate,
) -> ConsistencyJob:
    """Validate the request and store a ``pending`` consistency job.

    The SRS extraction and the code knowledge base must both belong to the
    user, the knowledge base must be ``active``, and at least one requirement
    must fall within the selected scope. The comparison settings are kept in
    the job's ``output_summary``.

    Raises:
        ValueError: If any of the above checks fails.
    """
    extraction = get_owned_srs_extraction(db, user_id, payload.srs_extraction_id)
    if not extraction:
        raise ValueError("SRS extraction does not exist.")

    code_kb = get_owned_code_kb(db, user_id, payload.code_kb_id)
    if not code_kb:
        raise ValueError("Code knowledge base does not exist.")
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")

    scope = db.query(SRSRequirement).filter(SRSRequirement.extraction_id == extraction.id)
    if payload.requirement_uids:
        # Restrict the scope to the explicitly selected requirements.
        scope = scope.filter(SRSRequirement.requirement_uid.in_(payload.requirement_uids))
    requirement_count = scope.count()
    if requirement_count == 0:
        raise ValueError("No SRS requirements matched the selected scope.")

    settings = {
        "requirement_uids": payload.requirement_uids,
        "top_k": payload.top_k,
        "max_call_hops": payload.max_call_hops,
        "min_similarity": payload.min_similarity,
        "use_llm": payload.use_llm,
    }
    job = ConsistencyJob(
        user_id=user_id,
        srs_extraction_id=extraction.id,
        code_kb_id=code_kb.id,
        status="pending",
        total_requirements=requirement_count,
        completed_requirements=0,
        output_summary=settings,
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    return job
def list_consistency_jobs(db: Session, user_id: int) -> List[ConsistencyJob]:
    """Return the consistency jobs owned by ``user_id``, newest first."""
    owned = db.query(ConsistencyJob).filter(ConsistencyJob.user_id == user_id)
    newest_first = owned.order_by(ConsistencyJob.created_at.desc())
    return newest_first.all()
def get_owned_consistency_job(db: Session, user_id: int, job_id: int) -> Optional[ConsistencyJob]:
    """Return the consistency job with this id if ``user_id`` owns it, else ``None``."""
    matches_id = ConsistencyJob.id == job_id
    matches_owner = ConsistencyJob.user_id == user_id
    query = db.query(ConsistencyJob).filter(matches_id, matches_owner)
    return query.first()
def ask_code_kb(
    code_kb: CodeKnowledgeBase,
    question: str,
    top_k: int = 6,
    min_similarity: float = 0.0,
    use_llm: bool = True,
    model_profile: Any = None,
) -> Dict[str, Any]:
    """Answer a question from the code knowledge base's function evidence.

    Relevant functions are retrieved and their call context (two hops) is
    expanded. Without hits, or with ``use_llm`` disabled, a fixed message is
    returned alongside the evidence. Otherwise the model is asked to answer
    from the evidence only; if that call fails, the error is reported in the
    answer and the evidence is still returned.

    Returns:
        A dictionary with ``answer``, ``evidence`` and ``raw_response``.

    Raises:
        ValueError: If the knowledge base is not active or no model profile
            is supplied.
    """
    if code_kb.status != "active":
        raise ValueError("Code knowledge base is not active.")
    if model_profile is None:
        raise ValueError("请先在 API 密钥页面新增并启用模型配置。")

    def reply(answer: str, items: List[Dict[str, Any]], raw: Optional[str]) -> Dict[str, Any]:
        return {"answer": answer, "evidence": items, "raw_response": raw}

    embeddings = EmbeddingsFactory.create(model_profile=model_profile)
    adapter = CodeKnowledgeBaseAdapter(embedding_function=embeddings)
    adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

    hits = adapter.search_functions(question, top_k=top_k, min_similarity=min_similarity)
    contexts = []
    for hit in hits:
        contexts.append(adapter.expand_call_context(hit.evidence.node_id, max_hops=2))
    evidence = [hit.to_dict() for hit in hits]

    if not hits:
        return reply("未检索到相关函数证据,无法基于代码知识库回答。", [], None)

    evidence_context = format_evidence_context(hits, contexts)
    if not use_llm:
        return reply(
            "已检索到相关函数证据,请查看 evidence 字段中的函数摘要、文件位置和调用链。",
            evidence,
            None,
        )

    prompt = (
        "你是代码知识库问答助手。只能基于给定代码证据回答问题;"
        "如果证据不足,请明确说明不足。回答需要包含关键函数名和文件位置。\n\n"
        f"问题:{question}\n\n代码证据:\n{evidence_context}"
    )
    try:
        llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        response = llm.invoke(prompt) if hasattr(llm, "invoke") else llm(prompt)
        answer = str(getattr(response, "content", response))
        return reply(answer, evidence, answer)
    except Exception as exc:
        # Best effort: keep the retrieved evidence available for manual review.
        return reply(
            f"模型问答失败,已返回检索证据供人工查看。错误:{exc}",
            evidence,
            None,
        )
def result_model_to_export_dict(result: ConsistencyResult) -> Dict[str, Any]:
    """Flatten a stored consistency result into a plain dict for export.

    Requirement title/type/text are read from the stored ``raw_judgment``;
    missing list fields become ``[]`` and a missing suggestion becomes ``""``.

    Args:
        result: Persisted comparison result for one requirement.

    Returns:
        A dict of export fields, including the ``raw_judgment`` itself.
    """
    judgment = result.raw_judgment or {}

    exported: Dict[str, Any] = {"requirement_uid": result.requirement_uid}
    exported["requirement_title"] = judgment.get("requirement_title", "")
    exported["requirement_type"] = judgment.get("requirement_type")
    exported["requirement_text"] = judgment.get("requirement_text", "")

    # Scalars are copied as-is, including None.
    for scalar_field in ("verdict", "coverage_score", "confidence"):
        exported[scalar_field] = getattr(result, scalar_field)

    # List-valued columns default to an empty list when unset.
    list_fields = (
        "matched_functions",
        "covered_points",
        "missing_points",
        "conflict_points",
        "call_chain_evidence",
    )
    for list_field in list_fields:
        exported[list_field] = getattr(result, list_field) or []

    exported["suggestion"] = result.suggestion or ""
    exported["raw_judgment"] = judgment
    return exported
def _store_result(db: Session, job: ConsistencyJob, result: Any) -> None:
    """Stage one comparison result for ``job`` in the session.

    The requirement title, type and text from ``result.to_dict()`` are folded
    into the stored ``raw_judgment``. The row is only added to the session;
    committing is left to the caller.

    Args:
        db: Open SQLAlchemy session.
        job: Consistency job the result belongs to.
        result: Comparator output exposing ``to_dict()`` and the result fields.
    """
    as_dict = result.to_dict()

    # Copy so the comparator's own raw_judgment is not mutated.
    judgment = dict(as_dict.get("raw_judgment") or {})
    for key in ("requirement_title", "requirement_type", "requirement_text"):
        judgment[key] = as_dict.get(key)

    row = ConsistencyResult(
        job_id=job.id,
        requirement_uid=result.requirement_uid,
        verdict=result.verdict,
        coverage_score=result.coverage_score,
        confidence=result.confidence,
        matched_functions=result.matched_functions,
        covered_points=result.covered_points,
        missing_points=result.missing_points,
        conflict_points=result.conflict_points,
        call_chain_evidence=result.call_chain_evidence,
        suggestion=result.suggestion,
        raw_judgment=judgment,
    )
    db.add(row)
def _create_srs_extraction_for_job(db: Session, job: ToolJob) -> SRSExtraction:
    """Run the SRS tool on the job's input file and stage the extraction rows.

    The user's active model configuration is required and marked as used.
    The extraction is flushed so its id is available to the requirement rows;
    the final commit is left to the caller.

    Args:
        db: Open SQLAlchemy session.
        job: Tool job whose ``input_file_path`` holds the requirement document.

    Returns:
        The newly created (flushed, not yet committed) ``SRSExtraction``.

    Raises:
        ValueError: From ``ModelConfigService.require_active_config`` when the
            user has no active model configuration.
    """
    profile = ModelConfigService.require_active_config(db, job.user_id)
    ModelConfigService.touch_last_used(db, profile)
    payload = get_srs_tool().run(job.input_file_path, model_profile=profile)

    document_name = payload["document_name"]
    extraction = SRSExtraction(
        job_id=job.id,
        document_name=document_name,
        document_title=payload.get("document_title") or document_name,
        generated_at=_parse_generated_at(payload.get("generated_at")),
        total_requirements=len(payload.get("requirements", [])),
        statistics=payload.get("statistics", {}),
        raw_output=payload.get("raw_output", {}),
    )
    db.add(extraction)
    db.flush()  # assigns extraction.id for the child rows below

    # Fields copied verbatim from the tool output (None when absent).
    passthrough_fields = (
        "section_uid",
        "section_number",
        "section_title",
        "requirement_type",
        "interface_name",
        "interface_type",
        "data_source",
        "data_destination",
    )
    for position, entry in enumerate(payload.get("requirements", [])):
        raw_uid = entry.get("id")
        copied = {name: entry.get(name) for name in passthrough_fields}
        db.add(
            SRSRequirement(
                extraction_id=extraction.id,
                # Synthesize REQ-001, REQ-002, ... when the tool gave no id.
                requirement_uid=raw_uid or f"REQ-{position + 1:03d}",
                title=_build_internal_title(entry.get("description"), raw_uid or "", position),
                description=entry.get("description") or "",
                priority=entry.get("priority") or "",
                acceptance_criteria=entry.get("acceptance_criteria") or ["待补充验收标准"],
                source_field=entry.get("source_field") or "文档解析",
                sort_order=int(entry.get("sort_order") or position),
                **copied,
            )
        )
    return extraction
def create_auto_consistency_tool_job(
    db: Session,
    user_id: int,
    requirement_file_path: str,
    requirement_file_name: str,
    code_source_dir: str,
    code_kb_name: str,
    top_k: int,
    max_call_hops: int,
    min_similarity: float,
    use_llm: bool,
    use_semantic: bool,
) -> ToolJob:
    """Persist a pending ``consistency.auto_compare`` tool job.

    All run options are stored in ``output_summary`` so the background runner
    can read them back from the job row.

    Args:
        db: Open SQLAlchemy session.
        user_id: Owner of the job.
        requirement_file_path: Stored path of the uploaded requirement document.
        requirement_file_name: Original name of that document.
        code_source_dir: Directory holding the uploaded source code.
        code_kb_name: Name for the code knowledge base to be built.
        top_k: Number of candidate functions retrieved per requirement.
        max_call_hops: Call-graph expansion depth.
        min_similarity: Retrieval similarity threshold.
        use_llm: Whether an LLM takes part in the comparison.
        use_semantic: Whether the code knowledge base build uses semantics.

    Returns:
        The committed and refreshed ``ToolJob``.
    """
    run_options = {
        "current_step": "pending",
        "code_source_dir": code_source_dir,
        "code_kb_name": code_kb_name,
        "top_k": top_k,
        "max_call_hops": max_call_hops,
        "min_similarity": min_similarity,
        "use_llm": use_llm,
        "use_semantic": use_semantic,
    }
    pending_job = ToolJob(
        user_id=user_id,
        tool_name="consistency.auto_compare",
        status="pending",
        input_file_name=requirement_file_name,
        input_file_path=requirement_file_path,
        output_summary=run_options,
    )
    db.add(pending_job)
    db.commit()
    db.refresh(pending_job)
    return pending_job
def get_owned_auto_job(db: Session, user_id: int, job_id: int) -> Optional[ToolJob]:
    """Fetch an auto-compare tool job only if it belongs to the given user.

    Args:
        db: Open SQLAlchemy session.
        user_id: Expected owner of the job.
        job_id: Primary key of the tool job.

    Returns:
        The matching ``consistency.auto_compare`` job, or ``None``.
    """
    criteria = (
        ToolJob.id == job_id,
        ToolJob.user_id == user_id,
        ToolJob.tool_name == "consistency.auto_compare",
    )
    return db.query(ToolJob).filter(*criteria).first()
def run_auto_consistency_job(tool_job_id: int) -> None:
    """Run the automatic requirement-vs-code comparison pipeline for a tool job.

    ``output_summary["current_step"]`` is updated as each stage starts:

    1. ``extracting_requirements`` - extract SRS requirements from the
       uploaded requirement document.
    2. ``building_code_kb`` - register a code knowledge base for the uploaded
       source directory and build it.
    3. ``comparing`` - create a consistency job and run it.

    The session is closed before handing off to ``run_code_kb_build`` and
    ``run_consistency_job`` (both are given only an id) and a fresh one is
    opened afterwards. Any exception marks the tool job as ``failed`` with a
    truncated error message; nothing is re-raised.

    Args:
        tool_job_id: Primary key of the ``ToolJob`` to run. A missing job is a
            silent no-op.
    """
    db = SessionLocal()
    try:
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job:
            return

        # --- Stage 1: requirement extraction --------------------------------
        options = tool_job.output_summary or {}
        tool_job.status = "processing"
        tool_job.started_at = datetime.utcnow()
        tool_job.output_summary = {**options, "current_step": "extracting_requirements"}
        db.add(tool_job)
        db.commit()

        extraction = _create_srs_extraction_for_job(db, tool_job)
        db.commit()
        # Keep ids as plain ints: once a session is committed and closed, ORM
        # instances may be expired and detached, and reading their attributes
        # can raise DetachedInstanceError (with expire_on_commit enabled).
        extraction_id = extraction.id

        # --- Stage 2: code knowledge base -----------------------------------
        options = tool_job.output_summary or options
        code_output_dir = str((AUTO_UPLOAD_ROOT / str(tool_job.id) / "code_kb").resolve())
        code_kb = create_uploaded_code_kb(
            db,
            tool_job.user_id,
            options.get("code_kb_name") or f"auto-code-kb-{tool_job.id}",
            options["code_source_dir"],
            code_output_dir,
        )
        code_kb_id = code_kb.id
        tool_job.output_summary = {
            **options,
            "current_step": "building_code_kb",
            "srs_extraction_id": extraction_id,
            "code_kb_id": code_kb_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_code_kb_build(code_kb_id, use_semantic=bool(options.get("use_semantic", True)))

        db = SessionLocal()
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == code_kb_id).first()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job or not code_kb:
            return
        if code_kb.status != "active":
            raise RuntimeError((code_kb.metadata_summary or {}).get("error_message") or "Code KB build failed.")

        # --- Stage 3: consistency comparison --------------------------------
        consistency_payload = ConsistencyJobCreate(
            srs_extraction_id=extraction_id,
            code_kb_id=code_kb_id,
            top_k=int(options.get("top_k", 8)),
            max_call_hops=int(options.get("max_call_hops", 2)),
            min_similarity=float(options.get("min_similarity", 0.55)),
            use_llm=bool(options.get("use_llm", True)),
        )
        consistency_job = create_consistency_job(db, tool_job.user_id, consistency_payload)
        consistency_job_id = consistency_job.id
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "comparing",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
        db.close()

        run_consistency_job(consistency_job_id)

        db = SessionLocal()
        consistency_job = (
            db.query(ConsistencyJob).filter(ConsistencyJob.id == consistency_job_id).first()
        )
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if not tool_job or not consistency_job:
            return
        if consistency_job.status == "failed":
            raise RuntimeError(consistency_job.error_message or "Consistency comparison failed.")

        tool_job.status = "completed"
        tool_job.completed_at = datetime.utcnow()
        tool_job.output_summary = {
            **(tool_job.output_summary or {}),
            "current_step": "completed",
            "consistency_job_id": consistency_job_id,
        }
        db.add(tool_job)
        db.commit()
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; without this the failure status below could not be written and
        # the job would stay "processing". On a clean or already-closed
        # session rollback() does nothing.
        db.rollback()
        tool_job = db.query(ToolJob).filter(ToolJob.id == tool_job_id).first()
        if tool_job:
            tool_job.status = "failed"
            tool_job.error_message = str(exc)[:2000]
            tool_job.completed_at = datetime.utcnow()
            tool_job.output_summary = {
                **(tool_job.output_summary or {}),
                "current_step": "failed",
            }
            db.add(tool_job)
            db.commit()
    finally:
        db.close()
def run_consistency_job(job_id: int) -> None:
    """Compare every selected requirement of a consistency job against its code KB.

    The job's options are read from ``output_summary`` (``use_llm``,
    ``top_k``, ``max_call_hops``, ``min_similarity`` and an optional
    ``requirement_uids`` filter). Progress is committed after each
    requirement so ``completed_requirements`` and the running
    ``verdict_counts`` are visible while the job runs. Any exception marks the
    job as ``failed`` with the error text; nothing is re-raised.

    Args:
        job_id: Primary key of the ``ConsistencyJob`` to run. A missing job is
            a silent no-op.
    """
    db = SessionLocal()
    job = None
    try:
        job = db.query(ConsistencyJob).filter(ConsistencyJob.id == job_id).first()
        if not job:
            return
        job.status = "processing"
        job.started_at = datetime.utcnow()
        db.add(job)
        db.commit()

        options = job.output_summary or {}
        code_kb = db.query(CodeKnowledgeBase).filter(CodeKnowledgeBase.id == job.code_kb_id).first()
        if not code_kb:
            raise RuntimeError("Code knowledge base does not exist.")

        model_profile = ModelConfigService.require_active_config(db, job.user_id)
        ModelConfigService.touch_last_used(db, model_profile)
        adapter = CodeKnowledgeBaseAdapter(
            embedding_function=EmbeddingsFactory.create(model_profile=model_profile)
        )
        adapter.load(code_kb.vector_path, code_kb.metadata_path, code_kb.graph_path)

        use_llm = bool(options.get("use_llm", True))
        llm = None
        if use_llm:
            llm = LLMFactory.create(temperature=0, streaming=False, model_profile=model_profile)
        comparator = ConsistencyComparator(adapter, llm=llm, use_llm=use_llm)

        query = (
            db.query(SRSRequirement)
            .filter(SRSRequirement.extraction_id == job.srs_extraction_id)
            .order_by(SRSRequirement.sort_order)
        )
        requirement_uids = options.get("requirement_uids")
        if requirement_uids:
            query = query.filter(SRSRequirement.requirement_uid.in_(requirement_uids))
        requirements = query.all()
        job.total_requirements = len(requirements)
        db.add(job)
        db.commit()

        # The comparison options do not change per requirement: parse them once.
        top_k = int(options.get("top_k", 8))
        max_call_hops = int(options.get("max_call_hops", 2))
        min_similarity = float(options.get("min_similarity", 0.55))

        verdict_counter: Counter[str] = Counter()
        for requirement in requirements:
            result = comparator.compare_requirement(
                requirement,
                top_k=top_k,
                max_call_hops=max_call_hops,
                min_similarity=min_similarity,
            )
            _store_result(db, job, result)
            verdict_counter[result.verdict] += 1
            # Tolerate a NULL counter on rows created without a default.
            job.completed_requirements = (job.completed_requirements or 0) + 1
            job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
            db.add(job)
            db.commit()  # per-requirement commit keeps progress observable

        job.status = "completed"
        job.completed_at = datetime.utcnow()
        job.output_summary = {**options, "verdict_counts": dict(verdict_counter)}
        db.add(job)
        db.commit()
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; without this the failure status could not be persisted and the
        # job would stay "processing".
        db.rollback()
        if job:
            job.status = "failed"
            job.error_message = str(exc)
            job.completed_at = datetime.utcnow()
            db.add(job)
            db.commit()
    finally:
        db.close()

View File

@@ -6,6 +6,7 @@ import traceback
import json
from app.db.session import SessionLocal
from io import BytesIO
from types import SimpleNamespace
from typing import Optional, List, Dict, Any
from fastapi import UploadFile
from langchain_community.document_loaders import (
@@ -26,6 +27,7 @@ from minio.error import MinioException
from minio.commonconfig import CopySource
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.model_config import ModelConfigService
class UploadResult(BaseModel):
file_path: str
@@ -120,7 +122,45 @@ def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> D
return sanitized
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
def _resolve_model_profile(db: Session, user_id: Optional[int]) -> Any:
if user_id is None:
return None
return ModelConfigService.require_active_config(db, user_id)
def _model_profile_snapshot(model_profile: Any) -> Any:
if model_profile is None:
return None
return SimpleNamespace(
provider=model_profile.provider,
api_key=model_profile.api_key,
api_base=model_profile.api_base,
chat_model=model_profile.chat_model,
embedding_model=model_profile.embedding_model,
)
def _load_model_profile_for_user(user_id: Optional[int]) -> Any:
if user_id is None:
return None
db = SessionLocal()
try:
model_profile = ModelConfigService.require_active_config(db, user_id)
ModelConfigService.touch_last_used(db, model_profile)
return _model_profile_snapshot(model_profile)
finally:
db.close()
async def process_document(
file_path: str,
file_name: str,
kb_id: int,
document_id: int,
chunk_size: int = 1000,
chunk_overlap: int = 200,
user_id: Optional[int] = None,
) -> None:
"""Process document and store in vector database with incremental updates"""
logger = logging.getLogger(__name__)
@@ -129,7 +169,8 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
# Initialize embeddings
logger.info("Initializing OpenAI embeddings...")
embeddings = EmbeddingsFactory.create()
model_profile = _load_model_profile_for_user(user_id)
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
vector_store = VectorStoreFactory.create(
@@ -202,7 +243,7 @@ async def process_document(file_path: str, file_name: str, kb_id: int, document_
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
graph_adapter = GraphRAGAdapter()
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info("GraphRAG ingestion completed in incremental processing")
@@ -323,7 +364,8 @@ async def process_document_background(
task_id: int,
db: Session = None,
chunk_size: int = 1000,
chunk_overlap: int = 200
chunk_overlap: int = 200,
user_id: Optional[int] = None,
) -> None:
"""Process document in background"""
logger = logging.getLogger(__name__)
@@ -348,6 +390,9 @@ async def process_document_background(
logger.info(f"Task {task_id}: Setting status to processing")
task.status = "processing"
db.commit()
model_profile = _resolve_model_profile(db, user_id)
if model_profile is not None:
ModelConfigService.touch_last_used(db, model_profile)
# 1. 从临时目录下载文件
minio_client = get_minio_client()
@@ -416,7 +461,7 @@ async def process_document_background(
# 3. 创建向量存储
logger.info(f"Task {task_id}: Initializing vector store")
embeddings = EmbeddingsFactory.create()
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
@@ -520,7 +565,7 @@ async def process_document_background(
from app.services.graph.graphrag_adapter import GraphRAGAdapter
logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
graph_adapter = GraphRAGAdapter()
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info(f"Task {task_id}: GraphRAG ingestion completed")

View File

@@ -1,30 +1,39 @@
from app.core.config import settings
from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from typing import Optional
# If you plan on adding other embeddings, import them here
# from some_other_module import AnotherEmbeddingClass
class EmbeddingsFactory:
@staticmethod
def create():
def create(provider: Optional[str] = None, model_profile: Optional[object] = None):
"""
Factory method to create an embeddings instance based on .env config.
"""
# Suppose your .env has a value like EMBEDDINGS_PROVIDER=openai
embeddings_provider = settings.EMBEDDINGS_PROVIDER.lower()
if model_profile is not None:
embeddings_provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
api_key = getattr(model_profile, "api_key", "") or ""
api_base = getattr(model_profile, "api_base", None) or _default_api_base(embeddings_provider)
model = getattr(model_profile, "embedding_model", None) or _default_embedding_model(embeddings_provider)
else:
embeddings_provider = (provider or settings.EMBEDDINGS_PROVIDER).lower()
api_key = _default_api_key(embeddings_provider)
api_base = _default_api_base(embeddings_provider)
model = _default_embedding_model(embeddings_provider)
if embeddings_provider == "openai":
return OpenAIEmbeddings(
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE,
model=settings.OPENAI_EMBEDDINGS_MODEL
openai_api_key=api_key,
openai_api_base=api_base,
model=model
)
elif embeddings_provider == "dashscope":
elif embeddings_provider in {"dashscope", "openai_compatible"}:
return OpenAIEmbeddings(
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
model=settings.DASH_SCOPE_EMBEDDINGS_MODEL,
openai_api_key=api_key,
openai_api_base=api_base,
model=model,
# DashScope OpenAI-compatible embedding expects string input,
# while LangChain's len-safe path may send token ids.
check_embedding_ctx_length=False,
@@ -35,8 +44,8 @@ class EmbeddingsFactory:
)
elif embeddings_provider == "ollama":
return OllamaEmbeddings(
model=settings.OLLAMA_EMBEDDINGS_MODEL,
base_url=settings.OLLAMA_API_BASE
model=model,
base_url=api_base
)
# Extend with other providers:
@@ -44,3 +53,34 @@ class EmbeddingsFactory:
# return AnotherEmbeddingClass(...)
else:
raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")
def _default_embedding_model(provider: Optional[str]) -> str:
    """Return the configured default embedding model name for a provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.EMBEDDINGS_PROVIDER``.

    Returns:
        The provider's model from settings. DashScope falls back to
        ``"text-embedding-v4"``; any other provider uses the DashScope model,
        then the OpenAI one.
    """
    key = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    # Lambdas keep each settings lookup lazy, as in a plain if-chain.
    resolvers = {
        "openai": lambda: settings.OPENAI_EMBEDDINGS_MODEL,
        "dashscope": lambda: settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
        "ollama": lambda: settings.OLLAMA_EMBEDDINGS_MODEL,
    }
    if key in resolvers:
        return resolvers[key]()
    return settings.DASH_SCOPE_EMBEDDINGS_MODEL or settings.OPENAI_EMBEDDINGS_MODEL
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key configured in settings for an embeddings provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.EMBEDDINGS_PROVIDER``.

    Returns:
        The OpenAI or DashScope key for those providers, otherwise
        ``settings.API_KEY``.
    """
    key = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    if key == "openai":
        api_key = settings.OPENAI_API_KEY
    elif key == "dashscope":
        api_key = settings.DASH_SCOPE_API_KEY
    else:
        api_key = settings.API_KEY
    return api_key
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL configured in settings for an embeddings provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.EMBEDDINGS_PROVIDER``.

    Returns:
        The base URL for OpenAI, DashScope or Ollama; any other provider gets
        the DashScope base URL.
    """
    key = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    if key == "openai":
        api_base = settings.OPENAI_API_BASE
    elif key == "ollama":
        api_base = settings.OLLAMA_API_BASE
    else:
        # "dashscope" and every unrecognized provider share this default.
        api_base = settings.DASH_SCOPE_API_BASE
    return api_base

View File

@@ -13,11 +13,11 @@ from app.services.llm.llm_factory import LLMFactory
class GraphRAGAdapter:
_instance_lock = asyncio.Lock()
def __init__(self):
def __init__(self, model_profile: Any = None):
self._graphrag_instances: Dict[int, Any] = {}
self._kb_locks: Dict[int, asyncio.Lock] = {}
self._embedding_model = EmbeddingsFactory.create()
self._llm_model = LLMFactory.create(streaming=False)
self._embedding_model = EmbeddingsFactory.create(model_profile=model_profile)
self._llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
self._symbols = self._load_symbols()
def _load_symbols(self) -> Dict[str, Any]:

View File

@@ -11,42 +11,51 @@ class LLMFactory:
provider: Optional[str] = None,
temperature: float = 0,
streaming: bool = True,
model_profile: Optional[object] = None,
) -> BaseChatModel:
"""
Create a LLM instance based on the provider
"""
# If no provider specified, use the one from settings
provider = provider or settings.CHAT_PROVIDER
if model_profile is not None:
provider = (provider or getattr(model_profile, "provider", None) or "dashscope").lower()
model = getattr(model_profile, "chat_model", None) or _default_chat_model(provider)
api_key = getattr(model_profile, "api_key", "") or ""
api_base = getattr(model_profile, "api_base", None) or _default_api_base(provider)
else:
provider = provider or settings.CHAT_PROVIDER
model = _default_chat_model(provider)
api_key = _default_api_key(provider)
api_base = _default_api_base(provider)
if provider.lower() == "openai":
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.OPENAI_MODEL,
openai_api_key=settings.OPENAI_API_KEY,
openai_api_base=settings.OPENAI_API_BASE
model=model,
openai_api_key=api_key,
openai_api_base=api_base
)
elif provider.lower() == "deepseek":
return ChatDeepSeek(
temperature=temperature,
streaming=streaming,
model=settings.DEEPSEEK_MODEL,
api_key=settings.DEEPSEEK_API_KEY,
api_base=settings.DEEPSEEK_API_BASE
model=model,
api_key=api_key,
api_base=api_base
)
elif provider.lower() == "dashscope":
elif provider.lower() in {"dashscope", "openai_compatible"}:
return ChatOpenAI(
temperature=temperature,
streaming=streaming,
model=settings.DASH_SCOPE_CHAT_MODEL,
openai_api_key=settings.DASH_SCOPE_API_KEY,
openai_api_base=settings.DASH_SCOPE_API_BASE,
model=model,
openai_api_key=api_key,
openai_api_base=api_base,
)
elif provider.lower() == "ollama":
# Initialize Ollama model
return OllamaLLM(
model=settings.OLLAMA_MODEL,
base_url=settings.OLLAMA_API_BASE,
model=model,
base_url=api_base,
temperature=temperature,
streaming=streaming
)
@@ -54,4 +63,41 @@ class LLMFactory:
# elif provider.lower() == "anthropic":
# return ChatAnthropic(...)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
raise ValueError(f"Unsupported LLM provider: {provider}")
def _default_chat_model(provider: Optional[str]) -> str:
    """Return the configured default chat model name for a provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.CHAT_PROVIDER``.

    Returns:
        The provider's chat model from settings; any other provider uses the
        DashScope chat model, then the OpenAI one.
    """
    key = (provider or settings.CHAT_PROVIDER).lower()
    # Lambdas keep each settings lookup lazy, as in a plain if-chain.
    resolvers = {
        "openai": lambda: settings.OPENAI_MODEL,
        "deepseek": lambda: settings.DEEPSEEK_MODEL,
        "dashscope": lambda: settings.DASH_SCOPE_CHAT_MODEL,
        "ollama": lambda: settings.OLLAMA_MODEL,
    }
    if key in resolvers:
        return resolvers[key]()
    return settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key configured in settings for a chat provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.CHAT_PROVIDER``.

    Returns:
        The OpenAI, DeepSeek or DashScope key for those providers, otherwise
        ``settings.API_KEY``.
    """
    key = (provider or settings.CHAT_PROVIDER).lower()
    if key == "openai":
        api_key = settings.OPENAI_API_KEY
    elif key == "deepseek":
        api_key = settings.DEEPSEEK_API_KEY
    elif key == "dashscope":
        api_key = settings.DASH_SCOPE_API_KEY
    else:
        api_key = settings.API_KEY
    return api_key
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL configured in settings for a chat provider.

    Args:
        provider: Provider name (case-insensitive); a falsy value falls back
            to ``settings.CHAT_PROVIDER``.

    Returns:
        The base URL for OpenAI, DeepSeek, DashScope or Ollama; any other
        provider gets the DashScope base URL.
    """
    key = (provider or settings.CHAT_PROVIDER).lower()
    if key == "openai":
        api_base = settings.OPENAI_API_BASE
    elif key == "deepseek":
        api_base = settings.DEEPSEEK_API_BASE
    elif key == "ollama":
        api_base = settings.OLLAMA_API_BASE
    else:
        # "dashscope" and every unrecognized provider share this default.
        api_base = settings.DASH_SCOPE_API_BASE
    return api_base

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.model_config import UserModelConfig
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
# Catalogue of the model providers a user can configure.
#
# Each entry carries the provider's identifier and display label, the default
# API base / chat model / embedding model, lists of suggested model names, and
# two flags read by _normalized_payload: "requires_api_key" and
# "supports_custom_api_base". Values taken from `settings` are evaluated once,
# when this module is imported. The first entry supplies the defaults returned
# by provider_options_response().
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
# DashScope: defaults come from settings, with literal fallbacks for the models.
{
"provider": "dashscope",
"label": "DashScope",
"default_api_base": settings.DASH_SCOPE_API_BASE,
"default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
"default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
"chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
"embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
# OpenAI: defaults come from settings, with literal fallbacks for the models.
{
"provider": "openai",
"label": "OpenAI",
"default_api_base": settings.OPENAI_API_BASE,
"default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
"default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
"chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
"embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
# Generic OpenAI-compatible endpoint: no default base URL, so
# _normalized_payload rejects a config that does not supply one.
{
"provider": "openai_compatible",
"label": "OpenAI Compatible",
"default_api_base": "",
"default_chat_model": "qwen3-max",
"default_embedding_model": "text-embedding-v4",
"chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
"embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
# Ollama: the only provider here that does not require an API key.
{
"provider": "ollama",
"label": "Ollama",
"default_api_base": settings.OLLAMA_API_BASE,
"default_chat_model": settings.OLLAMA_MODEL,
"default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
"chat_models": [settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"],
"embedding_models": [settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"],
"requires_api_key": False,
"supports_custom_api_base": True,
},
]
def provider_options_response() -> Dict[str, Any]:
    """Build the provider catalogue payload, with the first provider as default.

    Returns:
        A dict with ``providers`` (the full ``PROVIDER_OPTIONS`` list) and
        ``defaults`` (provider, API base, chat model and embedding model of
        the first catalogue entry).
    """
    default_option = PROVIDER_OPTIONS[0]
    defaults = {
        "provider": default_option["provider"],
        "api_base": default_option["default_api_base"],
        "chat_model": default_option["default_chat_model"],
        "embedding_model": default_option["default_embedding_model"],
    }
    return {"providers": PROVIDER_OPTIONS, "defaults": defaults}
def _provider_option(provider: str) -> Dict[str, Any]:
    """Return the catalogue entry for a provider name.

    Args:
        provider: Provider identifier; surrounding whitespace and case are
            ignored, and a falsy value means ``"dashscope"``.

    Returns:
        The matching entry of ``PROVIDER_OPTIONS``.

    Raises:
        ValueError: If no catalogue entry has that provider identifier.
    """
    wanted = (provider or "dashscope").strip().lower()
    match = next(
        (candidate for candidate in PROVIDER_OPTIONS if candidate["provider"] == wanted),
        None,
    )
    if match is None:
        raise ValueError(f"Unsupported model provider: {provider}")
    return match
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
    """Validate and normalize the fields of a model configuration.

    Values missing from ``payload`` are taken from ``existing`` (when
    updating) and then from the provider's catalogue defaults.

    Args:
        payload: Raw field values from a create or update request.
        existing: The configuration being updated, or ``None`` on create.

    Returns:
        A dict with ``name``, ``provider``, ``api_key``, ``api_base``,
        ``chat_model``, ``embedding_model`` and ``is_active``, ready to be
        assigned to a ``UserModelConfig``.

    Raises:
        ValueError: For an unsupported provider, a missing API base on the
            OpenAI-compatible provider, an empty chat or embedding model, or
            a missing API key on a provider that requires one.
    """

    def inherited(field: str) -> Any:
        # Payload value, or the existing config's value when the payload has None.
        value = payload.get(field)
        if value is None and existing is not None:
            value = getattr(existing, field)
        return value

    provider = str(payload.get("provider") or getattr(existing, "provider", "dashscope")).strip().lower()
    option = _provider_option(provider)

    api_base = inherited("api_base") or option["default_api_base"]
    base_is_mandatory = option["supports_custom_api_base"] and provider == "openai_compatible"
    if base_is_mandatory and not api_base:
        raise ValueError("OpenAI Compatible provider requires an API base URL.")

    chat_candidate = (
        payload.get("chat_model")
        or getattr(existing, "chat_model", "")
        or option["default_chat_model"]
    )
    embedding_candidate = (
        payload.get("embedding_model")
        or getattr(existing, "embedding_model", "")
        or option["default_embedding_model"]
    )
    chat_model = str(chat_candidate).strip()
    embedding_model = str(embedding_candidate).strip()
    if not chat_model:
        raise ValueError("Chat model is required.")
    if not embedding_model:
        raise ValueError("Embedding model is required.")

    api_key = str(inherited("api_key") or "").strip()
    if option["requires_api_key"] and not api_key:
        raise ValueError("API key is required for this provider.")

    # An empty name falls back to the provider's display label.
    name = str(inherited("name") or "").strip() or option["label"]

    is_active = bool(payload.get("is_active", getattr(existing, "is_active", True)))
    return {
        "name": name,
        "provider": provider,
        "api_key": api_key,
        "api_base": str(api_base).strip() if api_base else None,
        "chat_model": chat_model,
        "embedding_model": embedding_model,
        "is_active": is_active,
    }
class ModelConfigService:
    """Persistence helpers for per-user model configurations.

    Creating or updating a configuration as active deactivates the user's
    other configurations in the same transaction.
    """

    @staticmethod
    def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
        """Return the user's configs, active first, then most recently updated."""
        ordering = (UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
        owned = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
        return owned.order_by(*ordering).all()

    @staticmethod
    def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
        """Return the config with ``config_id`` if it belongs to ``user_id``, else ``None``."""
        criteria = (UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
        return db.query(UserModelConfig).filter(*criteria).first()

    @staticmethod
    def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
        """Return the user's most recently updated active config, or ``None``."""
        criteria = (UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
        active = db.query(UserModelConfig).filter(*criteria)
        return active.order_by(UserModelConfig.updated_at.desc()).first()

    @staticmethod
    def require_active_config(db: Session, user_id: int) -> UserModelConfig:
        """Return the user's active config.

        Raises:
            ValueError: If the user has no active configuration.
        """
        active = ModelConfigService.get_active_config(db, user_id)
        if active is not None:
            return active
        raise ValueError("请先在 API 密钥页面新增并启用模型配置。")

    @staticmethod
    def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
        """Validate ``payload`` and persist it as a new config for ``user_id``.

        Raises:
            ValueError: If ``_normalized_payload`` rejects the payload.
        """
        fields = _normalized_payload(payload.model_dump())
        if fields["is_active"]:
            ModelConfigService._deactivate_user_configs(db, user_id)
        created = UserModelConfig(user_id=user_id, **fields)
        db.add(created)
        db.commit()
        db.refresh(created)
        return created

    @staticmethod
    def update_config(
        db: Session,
        item: UserModelConfig,
        payload: ModelConfigUpdate,
    ) -> UserModelConfig:
        """Apply the fields set in ``payload`` to ``item`` and persist it.

        An empty-string ``api_key`` is treated as "not provided", so the
        stored key is kept.

        Raises:
            ValueError: If ``_normalized_payload`` rejects the merged values.
        """
        changes = payload.model_dump(exclude_unset=True)
        if changes.get("api_key") == "":
            del changes["api_key"]
        merged = _normalized_payload(changes, existing=item)
        if merged["is_active"]:
            ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
        for attribute in merged:
            setattr(item, attribute, merged[attribute])
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def delete_config(db: Session, item: UserModelConfig) -> None:
        """Delete ``item`` and commit."""
        db.delete(item)
        db.commit()

    @staticmethod
    def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Stamp ``item.last_used_at`` with the current UTC time and commit."""
        item.last_used_at = datetime.utcnow()
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Mark the user's configs inactive, optionally sparing ``exclude_id``.

        Issues a bulk UPDATE without committing; the caller commits.
        """
        criteria = [UserModelConfig.user_id == user_id]
        if exclude_id is not None:
            criteria.append(UserModelConfig.id != exclude_id)
        targets = db.query(UserModelConfig).filter(*criteria)
        targets.update({UserModelConfig.is_active: False}, synchronize_session=False)

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models.tooling import SRSExtraction, SRSRequirement, ToolJob
from app.services.model_config import ModelConfigService
from app.tools.srs_reqs_qwen import get_srs_tool
TYPE_TO_CHINESE = {
@@ -63,7 +64,9 @@ def run_srs_job(job_id: int) -> None:
job.error_message = None
db.commit()
payload = get_srs_tool().run(job.input_file_path)
model_profile = ModelConfigService.require_active_config(db, job.user_id)
ModelConfigService.touch_last_used(db, model_profile)
payload = get_srs_tool().run(job.input_file_path, model_profile=model_profile)
extraction = SRSExtraction(
job_id=job.id,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List
@@ -11,10 +12,15 @@ from app.db.session import SessionLocal
from app.models.knowledge import Document, KnowledgeBase
from app.models.tooling import TestingGeneration, ToolJob
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.llm.llm_factory import LLMFactory
from app.services.model_config import ModelConfigService
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
logger = logging.getLogger(__name__)
def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
items: List[Dict[str, Any]] = []
for current in value.values():
@@ -22,8 +28,15 @@ def _flatten_record(value: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, An
return items
def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
embeddings = EmbeddingsFactory.create()
def _build_kb_vector_stores(
db: Session,
knowledge_bases: List[KnowledgeBase],
model_profile: Any,
) -> List[Dict[str, Any]]:
if model_profile is None:
return []
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
kb_vector_stores: List[Dict[str, Any]] = []
for kb in knowledge_bases:
@@ -47,8 +60,9 @@ def _resolve_knowledge_context(
user_id: int,
requirement_text: str,
knowledge_base_id: int | None,
model_profile: Any,
) -> str:
if knowledge_base_id is None:
if knowledge_base_id is None or model_profile is None:
return ""
try:
@@ -60,7 +74,7 @@ def _resolve_knowledge_context(
)
.all()
)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases)
kb_vector_stores = _build_kb_vector_stores(db, knowledge_bases, model_profile)
if not kb_vector_stores:
return ""
@@ -143,6 +157,27 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
source_document_name = str(payload.get("source_document_name") or job.input_file_name or "")
source_job_id = payload.get("source_job_id")
knowledge_base_id = payload.get("knowledge_base_id")
model_profile = ModelConfigService.get_active_config(db, job.user_id)
if model_profile is not None:
ModelConfigService.touch_last_used(db, model_profile)
use_model_generation = model_profile is not None
llm_model = None
if use_model_generation:
try:
llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
except Exception as exc:
logger.exception(
"Testing generation LLM initialization failed for job=%s, falling back to rule-based output: %s",
job_id,
exc,
)
use_model_generation = False
else:
logger.info(
"Testing generation job=%s has no active model config; using rule-based output.",
job_id,
)
job.status = "processing"
job.started_at = datetime.utcnow()
@@ -183,6 +218,7 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
user_id=job.user_id,
requirement_text=description,
knowledge_base_id=knowledge_base_id,
model_profile=model_profile,
)
pipeline_result = run_testing_pipeline(
@@ -190,7 +226,8 @@ def run_testing_generation_job(job_id: int, payload: Dict[str, Any]) -> None:
requirement_type_input=req.get("requirementType"),
debug=False,
knowledge_context=knowledge_context,
use_model_generation=True,
use_model_generation=use_model_generation,
llm_model=llm_model,
max_items_per_group=12,
cases_per_item=2,
max_focus_points=6,

View File

@@ -33,13 +33,13 @@ def run_testing_pipeline(
debug: bool = False,
knowledge_context: Optional[str] = None,
use_model_generation: bool = False,
llm_model: Any = None,
max_items_per_group: int = 12,
cases_per_item: int = 2,
max_focus_points: int = 6,
max_llm_calls: int = 10,
) -> Dict[str, Any]:
llm_model = None
if use_model_generation:
if use_model_generation and llm_model is None:
try:
from app.services.llm.llm_factory import LLMFactory

View File

@@ -6,6 +6,7 @@ from typing import Generator, Tuple
from alembic.config import Config
from alembic.config import main as alembic_main
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection
@@ -52,13 +53,8 @@ class DatabaseMigrator:
with self.database_connection() as connection:
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
heads = context.get_current_heads()
if not heads:
logger.warning("No migration heads found. Database might not be initialized.")
return True, current_rev or "None", "head"
head_rev = heads[0]
script = ScriptDirectory.from_config(self.alembic_cfg)
head_rev = script.get_current_head()
return current_rev != head_rev, current_rev or "None", head_rev
def _get_alembic_config(self) -> Config:

View File

@@ -10,7 +10,7 @@ llm:
# 模型名称
model: "glm-5"
# API密钥建议使用环境变量 DASHSCOPE_API_KEY
api_key: "sk-7097f7842f724f0c9e70c4bf3b16dacb"
api_key: ""
# 可选参数
temperature: 0.3
max_tokens: 1024

View File

@@ -56,11 +56,11 @@ class SRSTool:
def __init__(self) -> None:
ToolRegistry.register(self.DEFINITION)
def run(self, file_path: str, enable_llm: bool = True) -> Dict[str, Any]:
def run(self, file_path: str, enable_llm: bool = True, model_profile: Any = None) -> Dict[str, Any]:
if not enable_llm:
raise ValueError("当前版本仅支持LLM模式请将 enable_llm 设为 true")
config = self._load_config()
config = self._load_config(model_profile=model_profile)
llm = self._build_llm(config, enable_llm=enable_llm)
parser = create_parser(file_path)
@@ -164,7 +164,7 @@ class SRSTool:
return text
return f"{text[:20].rstrip()}"
def _load_config(self) -> Dict[str, Any]:
def _load_config(self, model_profile: Any = None) -> Dict[str, Any]:
config_path = Path(__file__).with_name("default_config.yaml")
if config_path.exists():
with config_path.open("r", encoding="utf-8") as handle:
@@ -173,9 +173,14 @@ class SRSTool:
config = {}
config.setdefault("llm", {})
config["llm"]["model"] = settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
config["llm"]["api_key"] = settings.DASH_SCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
config["llm"]["api_base"] = settings.DASH_SCOPE_API_BASE
if model_profile is not None:
config["llm"]["model"] = getattr(model_profile, "chat_model", None) or settings.DASH_SCOPE_CHAT_MODEL
config["llm"]["api_key"] = getattr(model_profile, "api_key", "") or ""
config["llm"]["api_base"] = getattr(model_profile, "api_base", None) or settings.DASH_SCOPE_API_BASE
else:
config["llm"]["model"] = settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
config["llm"]["api_key"] = settings.DASH_SCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
config["llm"]["api_base"] = settings.DASH_SCOPE_API_BASE
config["llm"]["enabled"] = bool(config["llm"].get("api_key"))
return config
@@ -186,7 +191,7 @@ class SRSTool:
llm_cfg = config.get("llm", {})
api_key = llm_cfg.get("api_key")
if not api_key:
raise ValueError("未配置API密钥:请设置 DASH_SCOPE_API_KEY 或 DASHSCOPE_API_KEY")
raise ValueError("未配置模型 API 密钥,请先在 API 密钥页面新增并启用模型配置。")
return QwenLLM(
api_key=api_key,