Files

576 lines
17 KiB
Python
Raw Permalink Normal View History

2026-04-13 11:34:23 +08:00
import hashlib
from typing import List, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
from sqlalchemy.orm import Session
from langchain_chroma import Chroma
from sqlalchemy import text
import logging
from datetime import datetime, timedelta
from pydantic import BaseModel
from sqlalchemy.orm import selectinload
import time
import asyncio
from app.db.session import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.models.knowledge import KnowledgeBase, Document, ProcessingTask, DocumentChunk, DocumentUpload
from app.schemas.knowledge import (
KnowledgeBaseCreate,
KnowledgeBaseResponse,
KnowledgeBaseUpdate,
DocumentResponse,
PreviewRequest
)
from app.services.document_processor import process_document_background, upload_document, preview_document, PreviewResult
from app.core.config import settings
from app.core.minio import get_minio_client
from minio.error import MinioException
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
# Router on which every endpoint defined in this module is registered.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TestRetrievalRequest(BaseModel):
    """Request body accepted by the ``/test-retrieval`` endpoint."""

    # Text passed to the vector store similarity search.
    query: str
    # ID of the knowledge base whose ``kb_{kb_id}`` collection is searched.
    kb_id: int
    # Number of results requested from the similarity search (passed as ``k``).
    top_k: int
@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Create a knowledge base owned by the authenticated user.

    The name and description are taken from the request body. The new row is
    committed and refreshed before it is returned.
    """
    new_kb = KnowledgeBase(
        name=kb_in.name,
        description=kb_in.description,
        user_id=current_user.id,
    )

    # Persist the row, then reload it so database-generated values are present.
    db.add(new_kb)
    db.commit()
    db.refresh(new_kb)

    logger.info(f"Knowledge base created: {new_kb.name} for user {current_user.id}")
    return new_kb
@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """
    Retrieve one page of the authenticated user's knowledge bases.

    Args:
        db: Database session.
        current_user: Authenticated user; only their knowledge bases are listed.
        skip: Number of rows to skip before the page starts.
        limit: Maximum number of rows to return.

    Returns:
        The matching KnowledgeBase rows, ordered by ID.
    """
    return (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.user_id == current_user.id)
        # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
        # consecutive pages could overlap or skip rows. Ordering by ID keeps
        # pagination stable.
        .order_by(KnowledgeBase.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch a single knowledge base owned by the authenticated user.

    The documents of the knowledge base and each document's processing tasks
    are eagerly loaded with the same query. Responds with 404 when no
    knowledge base with this ID belongs to the user.
    """
    from sqlalchemy.orm import joinedload

    # Eager-load documents -> processing tasks in the same query.
    eager_documents = joinedload(KnowledgeBase.documents).joinedload(
        Document.processing_tasks
    )
    ownership = (
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )

    found = db.query(KnowledgeBase).options(eager_documents).filter(*ownership).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Apply a partial update to a knowledge base owned by the authenticated user.

    Only the fields explicitly set in the request body are written. Responds
    with 404 when no knowledge base with this ID belongs to the user.
    """
    target = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # exclude_unset leaves out fields the client did not send at all.
    changes = kb_in.dict(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    db.add(target)
    db.commit()
    db.refresh(target)

    logger.info(f"Knowledge base updated: {target.name} for user {current_user.id}")
    return target
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete a knowledge base and all associated resources.

    External resources are cleaned up first, on a best-effort basis: every
    MinIO object under the ``kb_{kb_id}/`` prefix is removed and the
    ``kb_{kb_id}`` vector store collection is deleted. A failure in either
    step is logged and reported as a warning in the response instead of
    aborting the deletion. The database record is deleted last.

    Returns:
        A dict with a ``message`` and, when a cleanup step failed, a
        ``warnings`` list describing the failures.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for this
            user; 500 if anything else fails (the session is rolled back).
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    try:
        # Initialize services
        minio_client = get_minio_client()
        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Clean up external resources first
        cleanup_errors = []

        # 1. Clean up MinIO files: delete all objects with prefix kb_{kb_id}/
        try:
            objects = minio_client.list_objects(settings.MINIO_BUCKET_NAME, prefix=f"kb_{kb_id}/")
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")

        # 2. Clean up vector store
        try:
            # NOTE(review): this reaches into the wrapper's private ``_store``
            # and passes the collection name explicitly - confirm the wrapped
            # store's delete_collection() accepts that argument.
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")

        # Finally, delete database records in a single transaction
        db.delete(kb)
        db.commit()

        # Report any cleanup errors in the response
        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors
            }
        return {"message": "Knowledge base and all associated resources deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete knowledge base: {str(e)}") from e
# Batch upload documents
@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload multiple documents to MinIO.

    A file whose name and SHA-256 hash both match an existing document of the
    knowledge base is reported with status ``exists`` and is not uploaded
    again. Every other file is stored under ``kb_{kb_id}/temp/`` and recorded
    as a DocumentUpload with status ``pending``.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for this
            user; 500 if MinIO rejects an upload.
    """
    owned_kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not owned_kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    outcomes = []
    for file in files:
        # 1. Compute the file hash.
        payload = await file.read()
        digest = hashlib.sha256(payload).hexdigest()

        # 2. Check for an existing document with the same name and the same hash.
        duplicate = (
            db.query(Document)
            .filter(
                Document.file_name == file.filename,
                Document.file_hash == digest,
                Document.knowledge_base_id == kb_id,
            )
            .first()
        )
        if duplicate:
            # Identical file: report the existing document and move on.
            already_stored = {
                "document_id": duplicate.id,
                "file_name": duplicate.file_name,
                "status": "exists",
                "message": "文件已存在且已处理完成",
                "skip_processing": True,
            }
            outcomes.append(already_stored)
            continue

        # 3. Upload to the temporary location.
        temp_path = f"kb_{kb_id}/temp/{file.filename}"
        # read() above consumed the stream; rewind before handing it to MinIO.
        await file.seek(0)
        try:
            minio_client = get_minio_client()
            minio_client.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=temp_path,
                data=file.file,
                # Size is taken from the content read earlier.
                length=len(payload),
                content_type=file.content_type,
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")

        # 4. Create the upload record.
        record = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=file.filename,
            file_hash=digest,
            file_size=len(payload),
            content_type=file.content_type,
            temp_path=temp_path,
        )
        db.add(record)
        db.commit()
        db.refresh(record)

        outcomes.append({
            "upload_id": record.id,
            "file_name": file.filename,
            "temp_path": temp_path,
            "status": "pending",
            "skip_processing": False,
        })
    return outcomes
@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Dict[int, PreviewResult]:
    """
    Preview the chunks of multiple documents.

    Each requested ID is looked up first as a Document of this knowledge
    base and, if none matches, as a DocumentUpload. Both lookups require the
    knowledge base to belong to the authenticated user. Responds with 404 as
    soon as an ID matches neither.
    """
    previews = {}
    for doc_id in preview_request.document_ids:
        document = (
            db.query(Document)
            .join(KnowledgeBase)
            .filter(
                Document.id == doc_id,
                Document.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )
        if not document:
            # Not a processed document: fall back to a pending upload.
            upload = (
                db.query(DocumentUpload)
                .join(KnowledgeBase)
                .filter(
                    DocumentUpload.id == doc_id,
                    DocumentUpload.knowledge_base_id == kb_id,
                    KnowledgeBase.user_id == current_user.id,
                )
                .first()
            )
            if not upload:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
            file_path = upload.temp_path
        else:
            file_path = document.file_path

        previews[doc_id] = await preview_document(
            file_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap,
        )
    return previews
@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Process multiple documents asynchronously.

    Creates one pending ProcessingTask per upload that belongs to this
    knowledge base and hands the batch to a background task for processing.

    Args:
        kb_id: ID of the knowledge base the uploads belong to.
        upload_results: Entries as returned by the upload endpoint. Entries
            with a truthy ``skip_processing`` are ignored; every other entry
            must carry an ``upload_id``.
        background_tasks: FastAPI background task registry.
        db: Database session.
        current_user: Authenticated user; must own the knowledge base.

    Returns:
        ``{"tasks": [{"upload_id": ..., "task_id": ...}, ...]}`` with one
        entry per created task. Upload IDs that are not found in this
        knowledge base are left out.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for this
            user; 400 if a non-skipped entry has no ``upload_id``.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    upload_ids = []
    for result in upload_results:
        if result.get("skip_processing"):
            continue
        upload_id = result.get("upload_id")
        if upload_id is None:
            # A malformed entry is a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="upload_id is required for entries that are not skipped",
            )
        upload_ids.append(upload_id)
    if not upload_ids:
        return {"tasks": []}

    # Restrict the lookup to this knowledge base so a caller cannot queue an
    # upload that belongs to another knowledge base by guessing its ID.
    uploads = (
        db.query(DocumentUpload)
        .filter(
            DocumentUpload.id.in_(upload_ids),
            DocumentUpload.knowledge_base_id == kb_id,
        )
        .all()
    )
    uploads_dict = {upload.id: upload for upload in uploads}

    # Keep each task together with the upload it was created for. Pairing
    # them again by list position would mis-pair as soon as one requested
    # upload ID is not found.
    scheduled = []
    for upload_id in upload_ids:
        upload = uploads_dict.get(upload_id)
        if not upload:
            continue
        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending"
        )
        scheduled.append((upload_id, upload, task))

    db.add_all([task for _, _, task in scheduled])
    db.commit()
    for _, _, task in scheduled:
        db.refresh(task)

    task_info = []
    task_data = []
    for upload_id, upload, task in scheduled:
        task_info.append({
            "upload_id": upload_id,
            "task_id": task.id
        })
        task_data.append({
            "task_id": task.id,
            "upload_id": upload_id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name
        })

    background_tasks.add_task(
        add_processing_tasks_to_queue,
        task_data,
        kb_id
    )
    return {"tasks": task_info}
# Strong references to in-flight document processing tasks. The event loop
# only keeps weak references to tasks, so a task nobody else references can
# be garbage-collected before it finishes.
_pending_processing_tasks = set()


async def add_processing_tasks_to_queue(task_data, kb_id):
    """
    Schedule document processing for each entry without blocking the response.

    Args:
        task_data: Iterable of dicts with the keys ``temp_path``,
            ``file_name`` and ``task_id``.
        kb_id: ID of the knowledge base the documents belong to.
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None
            )
        )
        # Hold the task until it completes, then drop the reference.
        _pending_processing_tasks.add(task)
        task.add_done_callback(_pending_processing_tasks.discard)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")
@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Clean up expired temporary files.

    Every DocumentUpload created more than 24 hours ago is removed: its
    temporary object is deleted from MinIO and its database row is deleted.
    The query is not scoped to the calling user. A MinIO failure is logged
    and the row is deleted regardless.
    """
    cutoff = datetime.utcnow() - timedelta(hours=24)
    stale_uploads = (
        db.query(DocumentUpload)
        .filter(DocumentUpload.created_at < cutoff)
        .all()
    )

    minio_client = get_minio_client()
    for stale in stale_uploads:
        try:
            minio_client.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=stale.temp_path,
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {stale.temp_path}: {str(e)}")
        db.delete(stale)

    # One commit for the whole batch of deletions.
    db.commit()
    return {"message": f"Cleaned up {len(stale_uploads)} expired uploads"}
@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Get status of multiple processing tasks.

    Args:
        kb_id: ID of the knowledge base the tasks belong to.
        task_ids: Comma-separated task IDs. Whitespace around an ID and
            empty segments (such as a trailing comma) are ignored.
        db: Database session.
        current_user: Authenticated user; must own the knowledge base.

    Returns:
        A dict keyed by task ID with the task's document ID, status, error
        message, upload ID and uploaded file name. Tasks that are not found
        in this knowledge base are left out.

    Raises:
        HTTPException: 400 if ``task_ids`` contains a non-integer value;
            404 if the knowledge base does not exist for this user.
    """
    try:
        task_id_list = [
            int(part.strip()) for part in task_ids.split(",") if part.strip()
        ]
    except ValueError:
        # Malformed input is a client error, not a server fault.
        raise HTTPException(
            status_code=400,
            detail="task_ids must be a comma-separated list of integers",
        ) from None

    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    if not task_id_list:
        # Nothing to look up; skip the query.
        return {}

    tasks = (
        db.query(ProcessingTask)
        .options(
            selectinload(ProcessingTask.document_upload)
        )
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id
        )
        .all()
    )
    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None
        }
        for task in tasks
    }
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Return the details of one document.

    The document must belong to the given knowledge base, and that knowledge
    base must belong to the authenticated user; otherwise the response is 404.
    """
    conditions = (
        Document.id == doc_id,
        Document.knowledge_base_id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )
    found = db.query(Document).join(KnowledgeBase).filter(*conditions).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Document not found")
@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Test retrieval quality for a given query against a knowledge base.

    Runs a similarity search for ``request.query`` against the
    ``kb_{kb_id}`` vector store collection and returns up to
    ``request.top_k`` matches.

    Returns:
        ``{"results": [...]}`` where each item holds the matched ``content``,
        its ``metadata`` and its ``score`` as a float.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for this
            user; 500 if the lookup or the search fails.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id
        ).first()
        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )

        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )
        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)

        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in results
        ]
        return {"results": response}
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the 404
        # raised above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.error(f"Test retrieval failed for kb {request.kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e