Initial project commit

This commit is contained in:
2026-04-13 11:34:23 +08:00
commit c7c0659a85
202 changed files with 31196 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
from fastapi import APIRouter
from app.api.api_v1 import api_keys, auth, chat, knowledge_base, testing, tools

api_router = APIRouter()

# (router, url prefix, openapi tag) for every v1 endpoint group.
# Registration order is preserved from the original wiring.
_REGISTRATIONS = (
    (auth.router, "/auth", "auth"),
    (knowledge_base.router, "/knowledge-base", "knowledge-base"),
    (chat.router, "/chat", "chat"),
    (api_keys.router, "/api-keys", "api-keys"),
    (testing.router, "/testing", "testing"),
    (tools.router, "/tools", "tools"),
)
for _sub_router, _prefix, _tag in _REGISTRATIONS:
    api_router.include_router(_sub_router, prefix=_prefix, tags=[_tag])

View File

@@ -0,0 +1,84 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
import logging
from app import models, schemas
from app.db.session import get_db
from app.services.api_key import APIKeyService
from app.core.security import get_current_user
# Router and module-level logger shared by all API-key endpoints below.
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/", response_model=List[schemas.APIKey])
def read_api_keys(
    db: Session = Depends(get_db),
    skip: int = 0,
    limit: int = 100,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """List the current user's API keys with skip/limit pagination."""
    return APIKeyService.get_api_keys(
        db=db, user_id=current_user.id, skip=skip, limit=limit
    )
@router.post("/", response_model=schemas.APIKey)
def create_api_key(
    *,
    db: Session = Depends(get_db),
    api_key_in: schemas.APIKeyCreate,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Create a new API key for the current user.

    The raw key value is returned to the caller but deliberately kept out of
    the logs: an API key is a credential, and writing it to log files leaks it.
    """
    api_key = APIKeyService.create_api_key(
        db=db, user_id=current_user.id, name=api_key_in.name
    )
    # Security fix: log the record id, never the secret key material.
    logger.info(f"API key created: id={api_key.id} for user {current_user.id}")
    return api_key
@router.put("/{id}", response_model=schemas.APIKey)
def update_api_key(
    *,
    db: Session = Depends(get_db),
    id: int,
    api_key_in: schemas.APIKeyUpdate,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Update an API key.

    Raises 404 when the key does not exist and 403 when it belongs to a
    different user.
    """
    api_key = APIKeyService.get_api_key(db=db, api_key_id=id)
    if not api_key:
        raise HTTPException(status_code=404, detail="API key not found")
    if api_key.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    api_key = APIKeyService.update_api_key(db=db, api_key=api_key, update_data=api_key_in)
    # Security fix: log the record id, never the secret key material.
    logger.info(f"API key updated: id={api_key.id} for user {current_user.id}")
    return api_key
@router.delete("/{id}", response_model=schemas.APIKey)
def delete_api_key(
    *,
    db: Session = Depends(get_db),
    id: int,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Delete an API key and return the deleted record.

    Raises 404 when the key does not exist and 403 when it belongs to a
    different user.
    """
    api_key = APIKeyService.get_api_key(db=db, api_key_id=id)
    if not api_key:
        raise HTTPException(status_code=404, detail="API key not found")
    if api_key.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    APIKeyService.delete_api_key(db=db, api_key=api_key)
    # Security fix: log the record id, never the secret key material.
    logger.info(f"API key deleted: id={api_key.id} for user {current_user.id}")
    return api_key

View File

@@ -0,0 +1,88 @@
from datetime import timedelta
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from requests.exceptions import RequestException
from app.core import security
from app.core.security import get_current_user
from app.core.config import settings
from app.db.session import get_db
from app.models.user import User
from app.schemas.token import Token
from app.schemas.user import UserCreate, UserResponse
# Router for authentication endpoints (register / token / test-token).
router = APIRouter()
@router.post("/register", response_model=UserResponse)
def register(*, db: Session = Depends(get_db), user_in: UserCreate) -> Any:
    """
    Register a new user.

    Raises 400 when the email or the username is already taken.
    """
    # NOTE(review): the original wrapped this body in
    # `except RequestException` from the `requests` library, but nothing
    # here performs an outbound HTTP request — only local DB access — so
    # that handler was unreachable dead code and has been removed.
    if db.query(User).filter(User.email == user_in.email).first():
        raise HTTPException(
            status_code=400,
            detail="A user with this email already exists.",
        )
    if db.query(User).filter(User.username == user_in.username).first():
        raise HTTPException(
            status_code=400,
            detail="A user with this username already exists.",
        )
    user = User(
        email=user_in.email,
        username=user_in.username,
        # Only the hash is stored; the plaintext password is discarded here.
        hashed_password=security.get_password_hash(user_in.password),
    )
    db.add(user)
    db.commit()
    db.refresh(user)
    return user
@router.post("/token", response_model=Token)
def login_access_token(
    db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
) -> Any:
    """
    OAuth2-compatible password login; returns a bearer access token.
    """
    def _unauthorized(detail: str) -> HTTPException:
        # All auth failures share the same 401 + WWW-Authenticate shape.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = db.query(User).filter(User.username == form_data.username).first()
    credentials_ok = user is not None and security.verify_password(
        form_data.password, user.hashed_password
    )
    if not credentials_ok:
        raise _unauthorized("Incorrect username or password")
    if not user.is_active:
        raise _unauthorized("Inactive user")
    expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    token = security.create_access_token(
        data={"sub": user.username}, expires_delta=expires
    )
    return {"access_token": token, "token_type": "bearer"}
@router.post("/test-token", response_model=UserResponse)
def test_token(current_user: User = Depends(get_current_user)) -> Any:
    """Validate the caller's access token by echoing back the resolved user."""
    return current_user

View File

@@ -0,0 +1,155 @@
from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session, joinedload
from app.db.session import get_db
from app.models.user import User
from app.models.chat import Chat, Message
from app.models.knowledge import KnowledgeBase
from app.schemas.chat import (
ChatCreate,
ChatResponse,
ChatUpdate,
MessageCreate,
MessageResponse
)
from app.core.security import get_current_user
from app.services.chat_service import generate_response
# Router for chat CRUD and message-streaming endpoints.
router = APIRouter()
@router.post("/", response_model=ChatResponse)
def create_chat(
    *,
    db: Session = Depends(get_db),
    chat_in: ChatCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Create a chat linked to one or more of the caller's knowledge bases."""
    requested_ids = chat_in.knowledge_base_ids
    owned_kbs = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id.in_(requested_ids),
            KnowledgeBase.user_id == current_user.id,
        )
        .all()
    )
    # Any id not returned either doesn't exist or belongs to another user.
    if len(owned_kbs) != len(requested_ids):
        raise HTTPException(
            status_code=400,
            detail="One or more knowledge bases not found"
        )
    chat = Chat(title=chat_in.title, user_id=current_user.id)
    chat.knowledge_bases = owned_kbs
    db.add(chat)
    db.commit()
    db.refresh(chat)
    return chat
@router.get("/", response_model=List[ChatResponse])
def get_chats(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """Return the caller's chats, paginated via skip/limit."""
    query = db.query(Chat).filter(Chat.user_id == current_user.id)
    return query.offset(skip).limit(limit).all()
@router.get("/{chat_id}", response_model=ChatResponse)
def get_chat(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Fetch a single chat owned by the caller."""
    chat = (
        db.query(Chat)
        .filter(Chat.id == chat_id, Chat.user_id == current_user.id)
        .first()
    )
    if chat is None:
        raise HTTPException(status_code=404, detail="Chat not found")
    return chat
@router.post("/{chat_id}/messages")
async def create_message(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    messages: dict,
    current_user: User = Depends(get_current_user)
) -> StreamingResponse:
    """
    Stream the assistant's reply to the latest user message of a chat.

    The request body must contain a "messages" list whose last entry has
    role "user". The reply is generated over the chat's knowledge bases and
    streamed back as server-sent events.

    Raises:
        404 if the chat does not exist or is not owned by the caller.
        400 if the payload has no messages or the last one is not from the user.
    """
    chat = (
        db.query(Chat)
        .options(joinedload(Chat.knowledge_bases))  # eager-load before streaming
        .filter(
            Chat.id == chat_id,
            Chat.user_id == current_user.id
        )
        .first()
    )
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found")
    # Robustness fix: a missing/empty "messages" list, or an entry without a
    # "role" key, previously raised KeyError/IndexError and surfaced as a
    # 500. Reject malformed payloads explicitly with a 400 instead.
    history = messages.get("messages") or []
    if not history:
        raise HTTPException(status_code=400, detail="No messages provided")
    last_message = history[-1]
    if last_message.get("role") != "user":
        raise HTTPException(status_code=400, detail="Last message must be from user")
    knowledge_base_ids = [kb.id for kb in chat.knowledge_bases]

    async def response_stream():
        # Relay generated chunks as they arrive.
        async for chunk in generate_response(
            query=last_message["content"],
            messages=messages,
            knowledge_base_ids=knowledge_base_ids,
            chat_id=chat_id,
            db=db
        ):
            yield chunk

    return StreamingResponse(
        response_stream(),
        media_type="text/event-stream",
        headers={
            # Marks the stream as Vercel AI data-stream protocol v1.
            "x-vercel-ai-data-stream": "v1"
        }
    )
@router.delete("/{chat_id}")
def delete_chat(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Delete one of the caller's chats."""
    chat = (
        db.query(Chat)
        .filter(Chat.id == chat_id, Chat.user_id == current_user.id)
        .first()
    )
    if chat is None:
        raise HTTPException(status_code=404, detail="Chat not found")
    db.delete(chat)
    db.commit()
    return {"status": "success"}

View File

@@ -0,0 +1,575 @@
import hashlib
from typing import List, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
from sqlalchemy.orm import Session
from langchain_chroma import Chroma
from sqlalchemy import text
import logging
from datetime import datetime, timedelta
from pydantic import BaseModel
from sqlalchemy.orm import selectinload
import time
import asyncio
from app.db.session import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.models.knowledge import KnowledgeBase, Document, ProcessingTask, DocumentChunk, DocumentUpload
from app.schemas.knowledge import (
KnowledgeBaseCreate,
KnowledgeBaseResponse,
KnowledgeBaseUpdate,
DocumentResponse,
PreviewRequest
)
from app.services.document_processor import process_document_background, upload_document, preview_document, PreviewResult
from app.core.config import settings
from app.core.minio import get_minio_client
from minio.error import MinioException
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
# Router and module-level logger for the knowledge-base endpoints.
router = APIRouter()
logger = logging.getLogger(__name__)
class TestRetrievalRequest(BaseModel):
    """Request body for the POST /test-retrieval endpoint."""
    query: str  # free-text query to search with
    kb_id: int  # target knowledge base id (must belong to the caller)
    top_k: int  # number of top-scoring chunks to return
@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Create a knowledge base owned by the current user."""
    kb = KnowledgeBase(
        name=kb_in.name,
        description=kb_in.description,
        user_id=current_user.id,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info(f"Knowledge base created: {kb.name} for user {current_user.id}")
    return kb
@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """List the caller's knowledge bases with skip/limit pagination."""
    query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == current_user.id)
    return query.offset(skip).limit(limit).all()
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Fetch one knowledge base, including documents and their processing tasks."""
    from sqlalchemy.orm import joinedload

    # Eager-load the nested relationships the response model serializes,
    # avoiding per-document lazy loads.
    kb = (
        db.query(KnowledgeBase)
        .options(
            joinedload(KnowledgeBase.documents).joinedload(Document.processing_tasks)
        )
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if kb is None:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return kb
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Apply the provided fields to a knowledge base owned by the caller."""
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if kb is None:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    # Only the fields the client actually sent are applied.
    changes = kb_in.dict(exclude_unset=True)
    for attr_name, attr_value in changes.items():
        setattr(kb, attr_name, attr_value)
    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info(f"Knowledge base updated: {kb.name} for user {current_user.id}")
    return kb
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete a knowledge base and all associated resources.

    External resources (MinIO objects, the vector-store collection) are
    cleaned up first, best-effort; the database rows are deleted last in a
    single transaction. Cleanup failures are reported as warnings in the
    response rather than aborting the deletion.
    """
    # Fix: removed a redundant `logger = logging.getLogger(__name__)` that
    # shadowed the module-level logger, and an unused `document_paths` local.
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    try:
        minio_client = get_minio_client()
        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )
        cleanup_errors = []
        # 1. MinIO: remove every object stored under this kb's prefix.
        try:
            objects = minio_client.list_objects(settings.MINIO_BUCKET_NAME, prefix=f"kb_{kb_id}/")
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")
        # 2. Vector store: drop this kb's collection.
        # NOTE(review): reaches into the private `_store` attribute of the
        # wrapper — consider exposing delete_collection on VectorStore.
        try:
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")
        # Finally, delete the database records in a single transaction.
        db.delete(kb)
        db.commit()
        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors
            }
        return {"message": "Knowledge base and all associated resources deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete knowledge base: {str(e)}")
# Batch upload documents
@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload multiple documents to MinIO.

    For each file: compute its SHA-256 hash, skip it when an identical file
    (same name AND hash) already exists in this knowledge base, otherwise
    stage it under a temp path in MinIO and record a DocumentUpload row.
    Returns one status entry per input file.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    results = []
    for file in files:
        # 1. Hash the full file content (length is also reused below).
        file_content = await file.read()
        file_hash = hashlib.sha256(file_content).hexdigest()
        # 2. Skip files whose name and hash both already exist in this kb.
        existing_document = db.query(Document).filter(
            Document.file_name == file.filename,
            Document.file_hash == file_hash,
            Document.knowledge_base_id == kb_id
        ).first()
        if existing_document:
            # Identical file already processed — report and move on.
            results.append({
                "document_id": existing_document.id,
                "file_name": existing_document.file_name,
                "status": "exists",
                "message": "文件已存在且已处理完成",
                "skip_processing": True
            })
            continue
        # 3. Stage the file in a temp directory inside the bucket.
        temp_path = f"kb_{kb_id}/temp/{file.filename}"
        await file.seek(0)  # rewind: the stream was fully consumed by read()
        try:
            minio_client = get_minio_client()
            file_size = len(file_content)  # size of the content read above
            minio_client.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=temp_path,
                data=file.file,
                length=file_size,  # put_object requires an explicit length
                content_type=file.content_type
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")
        # 4. Record the pending upload for later processing.
        upload = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=file.filename,
            file_hash=file_hash,
            file_size=len(file_content),
            content_type=file.content_type,
            temp_path=temp_path
        )
        db.add(upload)
        db.commit()
        db.refresh(upload)
        results.append({
            "upload_id": upload.id,
            "file_name": file.filename,
            "temp_path": temp_path,
            "status": "pending",
            "skip_processing": False
        })
    return results
@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Dict[int, PreviewResult]:
    """
    Preview how each requested document would be split into chunks.

    Ids may refer either to processed documents or to pending uploads; the
    matching stored file is chunked with the requested size/overlap.
    """
    results: Dict[int, PreviewResult] = {}
    for doc_id in preview_request.document_ids:
        # Prefer a fully processed document owned by the caller...
        document = (
            db.query(Document)
            .join(KnowledgeBase)
            .filter(
                Document.id == doc_id,
                Document.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )
        if document is not None:
            file_path = document.file_path
        else:
            # ...otherwise fall back to a not-yet-processed upload.
            upload = (
                db.query(DocumentUpload)
                .join(KnowledgeBase)
                .filter(
                    DocumentUpload.id == doc_id,
                    DocumentUpload.knowledge_base_id == kb_id,
                    KnowledgeBase.user_id == current_user.id,
                )
                .first()
            )
            if upload is None:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
            file_path = upload.temp_path
        results[doc_id] = await preview_document(
            file_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap,
        )
    return results
@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Queue background processing for previously uploaded documents.

    Creates one ProcessingTask per pending upload, then schedules the actual
    document processing to run after the response is sent. Returns the
    (upload_id, task_id) pairs so the client can poll task status.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    # Uploads the caller did not flag as already processed.
    upload_ids = [
        result["upload_id"]
        for result in upload_results
        if not result.get("skip_processing")
    ]
    if not upload_ids:
        return {"tasks": []}
    uploads = db.query(DocumentUpload).filter(DocumentUpload.id.in_(upload_ids)).all()
    uploads_dict = {upload.id: upload for upload in uploads}
    # Bug fix: the original paired tasks with uploads positionally
    # (`all_tasks[i]` vs `upload_ids[i]`), which silently mismatched
    # task<->upload whenever any upload id was missing from the database.
    # Pair each task with its upload explicitly instead.
    pairs = []  # (upload, task) for uploads that actually exist
    for upload_id in upload_ids:
        upload = uploads_dict.get(upload_id)
        if not upload:
            continue
        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending"
        )
        pairs.append((upload, task))
    db.add_all([task for _, task in pairs])
    db.commit()
    for _, task in pairs:
        db.refresh(task)
    task_info = []
    task_data = []
    for upload, task in pairs:
        task_info.append({
            "upload_id": upload.id,
            "task_id": task.id
        })
        task_data.append({
            "task_id": task.id,
            "upload_id": upload.id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name
        })
    background_tasks.add_task(
        add_processing_tasks_to_queue,
        task_data,
        kb_id
    )
    return {"tasks": task_info}
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a fire-and-forget create_task() result can be
# garbage-collected (and thereby cancelled) mid-processing; holding them here
# until completion prevents that.
_pending_processing_tasks: set = set()

async def add_processing_tasks_to_queue(task_data, kb_id):
    """
    Add document-processing coroutines to the event loop without blocking
    the main response.

    Each entry of `task_data` carries the temp path, file name and task id
    produced by `process_kb_documents`.
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None
            )
        )
        _pending_processing_tasks.add(task)
        task.add_done_callback(_pending_processing_tasks.discard)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")
@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Remove temporary uploads older than 24 hours (MinIO object + DB row).

    MinIO deletion is best-effort: a failed removal is logged and the
    database record is still deleted.
    """
    cutoff = datetime.utcnow() - timedelta(hours=24)
    stale_uploads = (
        db.query(DocumentUpload)
        .filter(DocumentUpload.created_at < cutoff)
        .all()
    )
    minio_client = get_minio_client()
    for upload in stale_uploads:
        try:
            minio_client.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=upload.temp_path
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {upload.temp_path}: {str(e)}")
        db.delete(upload)
    db.commit()
    return {"message": f"Cleaned up {len(stale_uploads)} expired uploads"}
@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Get the status of multiple processing tasks in one call.

    Returns a mapping of task id -> status details for tasks belonging to
    the given knowledge base.

    Raises:
        400 if `task_ids` is not a comma-separated list of integers.
        404 if the knowledge base does not exist or is not owned by the caller.
    """
    # Robustness fix: a malformed id previously raised ValueError and
    # surfaced as a 500; reject it explicitly as a 400. Also renamed the
    # loop variable, which shadowed the builtin `id`.
    try:
        task_id_list = [int(raw_id.strip()) for raw_id in task_ids.split(",")]
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="task_ids must be a comma-separated list of integers"
        )
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    tasks = (
        db.query(ProcessingTask)
        .options(
            # Load the related uploads in one extra query, not one per task.
            selectinload(ProcessingTask.document_upload)
        )
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id
        )
        .all()
    )
    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None
        }
        for task in tasks
    }
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Fetch a single document, verifying knowledge-base ownership."""
    document = (
        db.query(Document)
        .join(KnowledgeBase)
        .filter(
            Document.id == doc_id,
            Document.knowledge_base_id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if document is None:
        raise HTTPException(status_code=404, detail="Document not found")
    return document
@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Test retrieval quality for a query against one knowledge base.

    Returns the top-k most similar chunks with content, metadata and score.

    Raises:
        404 if the knowledge base does not exist or is not owned by the caller.
        500 for any retrieval/vector-store failure.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id
        ).first()
        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )
        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )
        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)
        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in results
        ]
        return {"results": response}
    except HTTPException:
        # Bug fix: the blanket Exception handler below used to swallow the
        # 404 raised above and re-raise it as a 500; propagate HTTP errors
        # unchanged instead.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -0,0 +1,84 @@
from typing import Any, Dict, List
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import Document, KnowledgeBase
from app.models.user import User
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
from app.services.testing_pipeline import run_testing_pipeline
from app.services.vector_store import VectorStoreFactory
# Router for the testing-content generation endpoints.
router = APIRouter()
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """Build one {"kb_id", "store"} entry per knowledge base that has documents."""
    embeddings = EmbeddingsFactory.create()
    stores: List[Dict[str, Any]] = []
    for kb in knowledge_bases:
        # Skip knowledge bases with no documents — nothing to retrieve from.
        kb_documents = db.query(Document).filter(Document.knowledge_base_id == kb.id).all()
        if not kb_documents:
            continue
        stores.append({
            "kb_id": kb.id,
            "store": VectorStoreFactory.create(
                store_type=settings.VECTOR_STORE_TYPE,
                collection_name=f"kb_{kb.id}",
                embedding_function=embeddings,
            ),
        })
    return stores
@router.post("/generate", response_model=TestingPipelineResponse)
async def generate_testing_content(
    *,
    payload: TestingPipelineRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> Any:
    """
    Run the testing-content pipeline for a requirement text.

    When knowledge base ids are supplied, relevant chunks are retrieved from
    the caller's knowledge bases and used as context; otherwise the payload's
    free-text knowledge_context (if any) is used as-is.
    """
    knowledge_context = (payload.knowledge_context or "").strip()
    kb_ids = payload.knowledge_base_ids
    if kb_ids:
        owned_kbs = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id.in_(kb_ids),
                KnowledgeBase.user_id == current_user.id,
            )
            .all()
        )
        stores = await _build_kb_vector_stores(db, owned_kbs)
        if stores:
            retriever = MultiKBRetriever(reranker_weight=settings.RERANKER_WEIGHT)
            rows = await retriever.retrieve(
                query=payload.requirement_text,
                kb_vector_stores=stores,
                # Over-fetch per kb so the reranker has candidates to cut down.
                fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
                top_k=payload.retrieval_top_k,
            )
            if rows:
                knowledge_context = format_retrieval_context(rows)
    return run_testing_pipeline(
        user_requirement_text=payload.requirement_text,
        requirement_type_input=payload.requirement_type,
        debug=payload.debug,
        knowledge_context=knowledge_context,
        use_model_generation=payload.use_model_generation,
        max_items_per_group=payload.max_items_per_group,
        cases_per_item=payload.cases_per_item,
        max_focus_points=payload.max_focus_points,
        max_llm_calls=payload.max_llm_calls,
    )

View File

@@ -0,0 +1,175 @@
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.tooling import SRSExtraction, ToolJob
from app.models.user import User
from app.schemas.tooling import (
SRSToolCreateJobResponse,
SRSToolJobStatusResponse,
SRSToolRequirementsSaveRequest,
SRSToolResultResponse,
ToolDefinitionResponse,
)
from app.services.srs_job_service import (
build_result_response,
ensure_upload_path,
replace_requirements,
run_srs_job,
)
from app.tools.registry import ToolRegistry
from app.tools.srs_reqs_qwen import get_srs_tool
# Router for the tooling endpoints (tool listing + SRS extraction jobs).
router = APIRouter()
# Register the SRS tool in the ToolRegistry when this router is imported.
get_srs_tool()
# Upload formats accepted by the SRS extraction job endpoint.
ALLOWED_EXTENSIONS = {".pdf", ".docx"}
@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return every registered tool definition."""
    _ = current_user  # dependency enforces authentication; value unused
    return ToolRegistry.list()
@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """
    Create an SRS requirement-extraction job from an uploaded .pdf/.docx.

    The job row is committed first because its id determines the on-disk
    upload path; the extraction itself runs in a background task after the
    response is sent.
    """
    # Strip any client-supplied directory components from the filename.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")
    # Persist the job first so job.id can be used for the upload path.
    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    target_path = ensure_upload_path(job.id, safe_name)
    try:
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        # Mark the job failed before surfacing the error to the client.
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        raise HTTPException(status_code=500, detail="上传文件保存失败")
    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()
    # Run the extraction after the response has been sent.
    background_tasks.add_task(run_srs_job, job.id)
    return {
        "job_id": job.id,
        "status": job.status,
    }
@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return status details for one of the caller's SRS jobs."""
    job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    extraction_id = extraction.id if extraction else None
    return {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }
@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of a completed SRS job (409 if not done)."""
    job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    if job.status != "completed":
        raise HTTPException(status_code=409, detail="任务尚未完成")
    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    if extraction is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")
    return build_result_response(job, extraction)
@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace the stored requirements of an SRS extraction with the payload."""
    job = (
        db.query(ToolJob)
        .filter(ToolJob.id == job_id, ToolJob.user_id == current_user.id)
        .first()
    )
    if job is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    extraction = (
        db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    )
    if extraction is None:
        raise HTTPException(status_code=404, detail="任务结果不存在")
    requirement_rows = [item.dict() for item in payload.requirements]
    replace_requirements(db, extraction, requirement_rows)
    db.add(extraction)
    db.commit()
    db.refresh(extraction)
    return build_result_response(job, extraction)