增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.api_v1 import api_keys, auth, chat, knowledge_base, testing, tools
|
||||
from app.api.api_v1 import api_keys, auth, chat, consistency, knowledge_base, model_configs, testing, tools
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -7,5 +7,7 @@ api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(knowledge_base.router, prefix="/knowledge-base", tags=["knowledge-base"])
|
||||
api_router.include_router(chat.router, prefix="/chat", tags=["chat"])
|
||||
api_router.include_router(api_keys.router, prefix="/api-keys", tags=["api-keys"])
|
||||
api_router.include_router(model_configs.router, prefix="/model-configs", tags=["model-configs"])
|
||||
api_router.include_router(testing.router, prefix="/testing", tags=["testing"])
|
||||
api_router.include_router(tools.router, prefix="/tools", tags=["tools"])
|
||||
api_router.include_router(tools.router, prefix="/tools", tags=["tools"])
|
||||
api_router.include_router(consistency.router, prefix="/consistency", tags=["consistency"])
|
||||
|
||||
@@ -39,7 +39,7 @@ def create_api_key(
|
||||
api_key = APIKeyService.create_api_key(
|
||||
db=db, user_id=current_user.id, name=api_key_in.name
|
||||
)
|
||||
logger.info(f"API key created: {api_key.key} for user {current_user.id}")
|
||||
logger.info("API key created: id=%s for user %s", api_key.id, current_user.id)
|
||||
return api_key
|
||||
|
||||
@router.put("/{id}", response_model=schemas.APIKey)
|
||||
@@ -60,7 +60,7 @@ def update_api_key(
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
api_key = APIKeyService.update_api_key(db=db, api_key=api_key, update_data=api_key_in)
|
||||
logger.info(f"API key updated: {api_key.key} for user {current_user.id}")
|
||||
logger.info("API key updated: id=%s for user %s", api_key.id, current_user.id)
|
||||
return api_key
|
||||
|
||||
@router.delete("/{id}", response_model=schemas.APIKey)
|
||||
@@ -80,5 +80,5 @@ def delete_api_key(
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
APIKeyService.delete_api_key(db=db, api_key=api_key)
|
||||
logger.info(f"API key deleted: {api_key.key} for user {current_user.id}")
|
||||
logger.info("API key deleted: id=%s for user %s", api_key.id, current_user.id)
|
||||
return api_key
|
||||
|
||||
@@ -120,7 +120,8 @@ async def create_message(
|
||||
messages=messages,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
chat_id=chat_id,
|
||||
db=db
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@@ -152,4 +153,4 @@ def delete_chat(
|
||||
|
||||
db.delete(chat)
|
||||
db.commit()
|
||||
return {"status": "success"}
|
||||
return {"status": "success"}
|
||||
|
||||
338
rag-web-ui/backend/app/api/api_v1/consistency.py
Normal file
338
rag-web-ui/backend/app/api/api_v1/consistency.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models.tooling import ConsistencyResult
|
||||
from app.models.user import User
|
||||
from app.schemas.consistency import (
|
||||
AutoConsistencyJobCreateResponse,
|
||||
AutoConsistencyJobStatusResponse,
|
||||
CodeKnowledgeBaseCreate,
|
||||
CodeKnowledgeBaseResponse,
|
||||
CodeKnowledgeBaseUploadResponse,
|
||||
CodeQuestionRequest,
|
||||
CodeQuestionResponse,
|
||||
ConsistencyJobCreate,
|
||||
ConsistencyJobCreateResponse,
|
||||
ConsistencyJobResponse,
|
||||
ConsistencyResultResponse,
|
||||
)
|
||||
from app.services.consistency.exporter import export_excel, export_json, export_markdown
|
||||
from app.services.consistency_job_service import (
|
||||
AUTO_UPLOAD_ROOT,
|
||||
CODE_UPLOAD_ROOT,
|
||||
ask_code_kb,
|
||||
create_code_kb,
|
||||
create_auto_consistency_tool_job,
|
||||
create_consistency_job,
|
||||
create_uploaded_code_kb,
|
||||
get_owned_auto_job,
|
||||
get_owned_code_kb,
|
||||
get_owned_consistency_job,
|
||||
list_code_kbs,
|
||||
list_consistency_jobs,
|
||||
result_model_to_export_dict,
|
||||
run_auto_consistency_job,
|
||||
run_code_kb_build,
|
||||
run_consistency_job,
|
||||
safe_upload_name,
|
||||
save_uploaded_bytes,
|
||||
)
|
||||
from app.services.model_config import ModelConfigService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def datetime_path() -> str:
|
||||
return datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
|
||||
|
||||
|
||||
async def _save_code_uploads(files: List[UploadFile], target_dir: Path) -> str:
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No code files uploaded.")
|
||||
source_dir = target_dir
|
||||
extracted_dirs: List[Path] = []
|
||||
for file in files:
|
||||
content = await file.read()
|
||||
if not content:
|
||||
continue
|
||||
saved = save_uploaded_bytes(target_dir, safe_upload_name(file.filename), content)
|
||||
if saved.is_dir():
|
||||
extracted_dirs.append(saved)
|
||||
if len(files) == 1 and extracted_dirs:
|
||||
source_dir = extracted_dirs[0]
|
||||
return str(source_dir.resolve())
|
||||
|
||||
|
||||
async def _save_requirement_upload(file: UploadFile, target_dir: Path) -> str:
|
||||
safe_name = safe_upload_name(file.filename)
|
||||
if Path(safe_name).suffix.lower() not in {".pdf", ".docx"}:
|
||||
raise HTTPException(status_code=400, detail="Requirement file must be .pdf or .docx.")
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="Requirement file is empty.")
|
||||
saved = save_uploaded_bytes(target_dir, safe_name, content)
|
||||
return str(saved.resolve())
|
||||
|
||||
|
||||
@router.post("/code-kbs", response_model=CodeKnowledgeBaseResponse)
|
||||
async def register_code_kb(
|
||||
payload: CodeKnowledgeBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
try:
|
||||
return create_code_kb(db, current_user.id, payload)
|
||||
except (FileNotFoundError, RuntimeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/code-kbs/upload", response_model=CodeKnowledgeBaseUploadResponse)
|
||||
async def upload_and_build_code_kb(
|
||||
background_tasks: BackgroundTasks,
|
||||
name: str = Form(...),
|
||||
use_semantic: bool = Form(True),
|
||||
files: List[UploadFile] = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
if use_semantic:
|
||||
try:
|
||||
ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
target_dir = CODE_UPLOAD_ROOT / str(current_user.id) / datetime_path()
|
||||
code_source_dir = await _save_code_uploads(files, target_dir)
|
||||
output_dir = str((target_dir / "artifacts").resolve())
|
||||
code_kb = create_uploaded_code_kb(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
name=name,
|
||||
project_path=code_source_dir,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
background_tasks.add_task(run_code_kb_build, code_kb.id, use_semantic)
|
||||
return {"id": code_kb.id, "status": code_kb.status}
|
||||
|
||||
|
||||
@router.get("/code-kbs", response_model=List[CodeKnowledgeBaseResponse])
|
||||
async def get_code_kbs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
return list_code_kbs(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/code-kbs/{code_kb_id}", response_model=CodeKnowledgeBaseResponse)
|
||||
async def get_code_kb(
|
||||
code_kb_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
code_kb = get_owned_code_kb(db, current_user.id, code_kb_id)
|
||||
if not code_kb:
|
||||
raise HTTPException(status_code=404, detail="Code knowledge base not found.")
|
||||
return code_kb
|
||||
|
||||
|
||||
@router.delete("/code-kbs/{code_kb_id}")
|
||||
async def delete_code_kb(
|
||||
code_kb_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
code_kb = get_owned_code_kb(db, current_user.id, code_kb_id)
|
||||
if not code_kb:
|
||||
raise HTTPException(status_code=404, detail="Code knowledge base not found.")
|
||||
db.delete(code_kb)
|
||||
db.commit()
|
||||
return {"message": "deleted"}
|
||||
|
||||
|
||||
@router.post("/code-kbs/{code_kb_id}/ask", response_model=CodeQuestionResponse)
|
||||
async def ask_code_kb_api(
|
||||
code_kb_id: int,
|
||||
payload: CodeQuestionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
code_kb = get_owned_code_kb(db, current_user.id, code_kb_id)
|
||||
if not code_kb:
|
||||
raise HTTPException(status_code=404, detail="Code knowledge base not found.")
|
||||
try:
|
||||
model_profile = ModelConfigService.require_active_config(db, current_user.id)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
return ask_code_kb(
|
||||
code_kb=code_kb,
|
||||
question=payload.question,
|
||||
top_k=payload.top_k,
|
||||
min_similarity=payload.min_similarity,
|
||||
use_llm=payload.use_llm,
|
||||
model_profile=model_profile,
|
||||
)
|
||||
except (RuntimeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/jobs", response_model=ConsistencyJobCreateResponse)
|
||||
async def create_job(
|
||||
background_tasks: BackgroundTasks,
|
||||
payload: ConsistencyJobCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
try:
|
||||
ModelConfigService.require_active_config(db, current_user.id)
|
||||
job = create_consistency_job(db, current_user.id, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
background_tasks.add_task(run_consistency_job, job.id)
|
||||
return {"job_id": job.id, "status": job.status}
|
||||
|
||||
|
||||
@router.post("/auto-jobs", response_model=AutoConsistencyJobCreateResponse)
|
||||
async def create_auto_job(
|
||||
background_tasks: BackgroundTasks,
|
||||
requirement_file: UploadFile = File(...),
|
||||
code_files: List[UploadFile] = File(...),
|
||||
code_kb_name: str = Form("uploaded-code-kb"),
|
||||
top_k: int = Form(8),
|
||||
max_call_hops: int = Form(2),
|
||||
min_similarity: float = Form(0.55),
|
||||
use_llm: bool = Form(True),
|
||||
use_semantic: bool = Form(True),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
try:
|
||||
ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
timestamp = datetime_path()
|
||||
target_dir = AUTO_UPLOAD_ROOT / str(current_user.id) / timestamp
|
||||
requirement_path = await _save_requirement_upload(requirement_file, target_dir / "requirement")
|
||||
code_source_dir = await _save_code_uploads(code_files, target_dir / "code")
|
||||
tool_job = create_auto_consistency_tool_job(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
requirement_file_path=requirement_path,
|
||||
requirement_file_name=safe_upload_name(requirement_file.filename),
|
||||
code_source_dir=code_source_dir,
|
||||
code_kb_name=code_kb_name,
|
||||
top_k=top_k,
|
||||
max_call_hops=max_call_hops,
|
||||
min_similarity=min_similarity,
|
||||
use_llm=use_llm,
|
||||
use_semantic=use_semantic,
|
||||
)
|
||||
background_tasks.add_task(run_auto_consistency_job, tool_job.id)
|
||||
return {"job_id": tool_job.id, "status": tool_job.status}
|
||||
|
||||
|
||||
@router.get("/auto-jobs/{job_id}", response_model=AutoConsistencyJobStatusResponse)
|
||||
async def get_auto_job(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = get_owned_auto_job(db, current_user.id, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Auto consistency job not found.")
|
||||
summary = job.output_summary or {}
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"error_message": job.error_message,
|
||||
"current_step": summary.get("current_step"),
|
||||
"srs_extraction_id": summary.get("srs_extraction_id"),
|
||||
"code_kb_id": summary.get("code_kb_id"),
|
||||
"consistency_job_id": summary.get("consistency_job_id"),
|
||||
"created_at": job.created_at,
|
||||
"updated_at": job.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=List[ConsistencyJobResponse])
|
||||
async def get_jobs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
return list_consistency_jobs(db, current_user.id)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=ConsistencyJobResponse)
|
||||
async def get_job(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = get_owned_consistency_job(db, current_user.id, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Consistency job not found.")
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/results", response_model=List[ConsistencyResultResponse])
|
||||
async def get_job_results(
|
||||
job_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
job = get_owned_consistency_job(db, current_user.id, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Consistency job not found.")
|
||||
return (
|
||||
db.query(ConsistencyResult)
|
||||
.filter(ConsistencyResult.job_id == job.id)
|
||||
.order_by(ConsistencyResult.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/export")
|
||||
async def export_job_results(
|
||||
job_id: int,
|
||||
format: str = Query(default="json", pattern="^(json|markdown|md|excel|xlsx)$"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Response:
|
||||
job = get_owned_consistency_job(db, current_user.id, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Consistency job not found.")
|
||||
rows = (
|
||||
db.query(ConsistencyResult)
|
||||
.filter(ConsistencyResult.job_id == job.id)
|
||||
.order_by(ConsistencyResult.id)
|
||||
.all()
|
||||
)
|
||||
payload = [result_model_to_export_dict(row) for row in rows]
|
||||
if format in {"markdown", "md"}:
|
||||
content = export_markdown(payload).encode("utf-8")
|
||||
return Response(
|
||||
content,
|
||||
media_type="text/markdown; charset=utf-8",
|
||||
headers={"Content-Disposition": f'attachment; filename="consistency-job-{job.id}.md"'},
|
||||
)
|
||||
if format in {"excel", "xlsx"}:
|
||||
try:
|
||||
content = export_excel(payload)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
return Response(
|
||||
content,
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
headers={"Content-Disposition": f'attachment; filename="consistency-job-{job.id}.xlsx"'},
|
||||
)
|
||||
return Response(
|
||||
export_json(payload),
|
||||
media_type="application/json; charset=utf-8",
|
||||
headers={"Content-Disposition": f'attachment; filename="consistency-job-{job.id}.json"'},
|
||||
)
|
||||
@@ -28,6 +28,7 @@ from app.core.minio import get_minio_client
|
||||
from minio.error import MinioException
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -163,19 +164,23 @@ async def delete_knowledge_base(
|
||||
# Get all document file paths before deletion
|
||||
document_paths = [doc.file_path for doc in kb.documents]
|
||||
|
||||
cleanup_errors = []
|
||||
|
||||
# Initialize services
|
||||
minio_client = get_minio_client()
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
|
||||
vector_store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
collection_name=f"kb_{kb_id}",
|
||||
embedding_function=embeddings,
|
||||
)
|
||||
vector_store = None
|
||||
try:
|
||||
model_profile = ModelConfigService.get_active_config(db, current_user.id)
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
vector_store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
collection_name=f"kb_{kb_id}",
|
||||
embedding_function=embeddings,
|
||||
)
|
||||
except Exception as e:
|
||||
cleanup_errors.append(f"Failed to initialize vector store cleanup: {str(e)}")
|
||||
|
||||
# Clean up external resources first
|
||||
cleanup_errors = []
|
||||
|
||||
# 1. Clean up MinIO files
|
||||
try:
|
||||
# Delete all objects with prefix kb_{kb_id}/
|
||||
@@ -188,12 +193,13 @@ async def delete_knowledge_base(
|
||||
logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")
|
||||
|
||||
# 2. Clean up vector store
|
||||
try:
|
||||
vector_store._store.delete_collection(f"kb_{kb_id}")
|
||||
logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
|
||||
except Exception as e:
|
||||
cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
|
||||
logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")
|
||||
if vector_store is not None:
|
||||
try:
|
||||
vector_store._store.delete_collection(f"kb_{kb_id}")
|
||||
logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
|
||||
except Exception as e:
|
||||
cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
|
||||
logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")
|
||||
|
||||
# Finally, delete database records in a single transaction
|
||||
db.delete(kb)
|
||||
@@ -366,6 +372,11 @@ async def process_kb_documents(
|
||||
|
||||
if not upload_ids:
|
||||
return {"tasks": []}
|
||||
|
||||
try:
|
||||
ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
uploads = db.query(DocumentUpload).filter(DocumentUpload.id.in_(upload_ids)).all()
|
||||
uploads_dict = {upload.id: upload for upload in uploads}
|
||||
@@ -411,12 +422,13 @@ async def process_kb_documents(
|
||||
background_tasks.add_task(
|
||||
add_processing_tasks_to_queue,
|
||||
task_data,
|
||||
kb_id
|
||||
kb_id,
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
return {"tasks": task_info}
|
||||
|
||||
async def add_processing_tasks_to_queue(task_data, kb_id):
|
||||
async def add_processing_tasks_to_queue(task_data, kb_id, user_id):
|
||||
"""Helper function to add document processing tasks to the queue without blocking the main response."""
|
||||
for data in task_data:
|
||||
asyncio.create_task(
|
||||
@@ -425,7 +437,8 @@ async def add_processing_tasks_to_queue(task_data, kb_id):
|
||||
data["file_name"],
|
||||
kb_id,
|
||||
data["task_id"],
|
||||
None
|
||||
None,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Added {len(task_data)} document processing tasks to queue")
|
||||
@@ -551,7 +564,11 @@ async def test_retrieval(
|
||||
detail=f"Knowledge base {request.kb_id} not found",
|
||||
)
|
||||
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
try:
|
||||
model_profile = ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
|
||||
vector_store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
@@ -571,5 +588,7 @@ async def test_retrieval(
|
||||
|
||||
return {"results": response}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
78
rag-web-ui/backend/app/api/api_v1/model_configs.py
Normal file
78
rag-web-ui/backend/app/api/api_v1/model_configs.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.model_config import (
|
||||
ModelConfigCreate,
|
||||
ModelConfigResponse,
|
||||
ModelConfigUpdate,
|
||||
ModelProviderOptionsResponse,
|
||||
)
|
||||
from app.services.model_config import ModelConfigService, provider_options_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ModelProviderOptionsResponse)
|
||||
def read_model_provider_options(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
_ = current_user
|
||||
return provider_options_response()
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelConfigResponse])
|
||||
def read_model_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
return ModelConfigService.list_configs(db, current_user.id)
|
||||
|
||||
|
||||
@router.post("", response_model=ModelConfigResponse)
|
||||
def create_model_config(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
payload: ModelConfigCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
try:
|
||||
return ModelConfigService.create_config(db, current_user.id, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=ModelConfigResponse)
|
||||
def update_model_config(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
config_id: int,
|
||||
payload: ModelConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
item = ModelConfigService.get_config(db, current_user.id, config_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Model config not found.")
|
||||
try:
|
||||
return ModelConfigService.update_config(db, item, payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.delete("/{config_id}", response_model=ModelConfigResponse)
|
||||
def delete_model_config(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
item = ModelConfigService.get_config(db, current_user.id, config_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Model config not found.")
|
||||
response = ModelConfigResponse.model_validate(item)
|
||||
ModelConfigService.delete_config(db, item)
|
||||
return response
|
||||
@@ -12,6 +12,8 @@ from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.models.user import User
|
||||
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline import run_testing_pipeline
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
@@ -21,8 +23,12 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_PIPELINE_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
async def _build_kb_vector_stores(
|
||||
db: Session,
|
||||
knowledge_bases: List[KnowledgeBase],
|
||||
model_profile: Any,
|
||||
) -> List[Dict[str, Any]]:
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
kb_vector_stores: List[Dict[str, Any]] = []
|
||||
|
||||
for kb in knowledge_bases:
|
||||
@@ -47,10 +53,9 @@ async def generate_testing_content(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
_ = current_user
|
||||
|
||||
model_profile = ModelConfigService.get_active_config(db, current_user.id)
|
||||
knowledge_context = (payload.knowledge_context or "").strip()
|
||||
if payload.knowledge_base_ids:
|
||||
if payload.knowledge_base_ids and model_profile is not None:
|
||||
try:
|
||||
knowledge_bases = (
|
||||
db.query(KnowledgeBase)
|
||||
@@ -61,7 +66,7 @@ async def generate_testing_content(
|
||||
.all()
|
||||
)
|
||||
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases)
|
||||
kb_vector_stores = await _build_kb_vector_stores(db, knowledge_bases, model_profile)
|
||||
if kb_vector_stores:
|
||||
retriever = MultiKBRetriever(
|
||||
reranker_weight=settings.RERANKER_WEIGHT,
|
||||
@@ -81,13 +86,39 @@ async def generate_testing_content(
|
||||
payload.knowledge_base_ids,
|
||||
exc,
|
||||
)
|
||||
elif payload.knowledge_base_ids:
|
||||
logger.warning(
|
||||
"Testing generation skipped knowledge retrieval because no active model config exists: user=%s knowledge_base_ids=%s",
|
||||
current_user.id,
|
||||
payload.knowledge_base_ids,
|
||||
)
|
||||
|
||||
use_model_generation = bool(payload.use_model_generation and model_profile is not None)
|
||||
llm_model = None
|
||||
if use_model_generation:
|
||||
try:
|
||||
llm_model = LLMFactory.create(streaming=False, model_profile=model_profile)
|
||||
ModelConfigService.touch_last_used(db, model_profile)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Testing generation LLM initialization failed for user=%s, falling back to rule-based output: %s",
|
||||
current_user.id,
|
||||
exc,
|
||||
)
|
||||
use_model_generation = False
|
||||
elif payload.use_model_generation:
|
||||
logger.warning(
|
||||
"Testing generation falling back to rule-based output because no active model config exists: user=%s",
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
pipeline_kwargs = {
|
||||
"user_requirement_text": payload.requirement_text,
|
||||
"requirement_type_input": payload.requirement_type,
|
||||
"debug": payload.debug,
|
||||
"knowledge_context": knowledge_context,
|
||||
"use_model_generation": payload.use_model_generation,
|
||||
"use_model_generation": use_model_generation,
|
||||
"llm_model": llm_model,
|
||||
"max_items_per_group": payload.max_items_per_group,
|
||||
"cases_per_item": payload.cases_per_item,
|
||||
"max_focus_points": payload.max_focus_points,
|
||||
|
||||
@@ -34,6 +34,7 @@ from app.services.srs_job_service import (
|
||||
replace_requirements,
|
||||
run_srs_job,
|
||||
)
|
||||
from app.services.model_config import ModelConfigService
|
||||
from app.services.testing_generation_service import (
|
||||
build_testing_generation_response,
|
||||
create_testing_generation,
|
||||
@@ -72,6 +73,11 @@ async def create_srs_job(
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")
|
||||
|
||||
try:
|
||||
ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
job = ToolJob(
|
||||
user_id=current_user.id,
|
||||
tool_name="srs.requirement_extractor",
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.db.session import get_db
|
||||
from app.core.security import get_api_key_user
|
||||
from app.core.config import settings
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.model_config import ModelConfigService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -36,7 +37,11 @@ def query_knowledge_base(
|
||||
detail=f"Knowledge base {knowledge_base_id} not found",
|
||||
)
|
||||
|
||||
embeddings = EmbeddingsFactory.create()
|
||||
try:
|
||||
model_profile = ModelConfigService.require_active_config(db, current_user.id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
|
||||
|
||||
vector_store = VectorStoreFactory.create(
|
||||
store_type=settings.VECTOR_STORE_TYPE,
|
||||
@@ -56,5 +61,7 @@ def query_knowledge_base(
|
||||
|
||||
return {"results": response}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user