576 lines
17 KiB
Python
576 lines
17 KiB
Python
import hashlib
|
||
from typing import List, Any, Dict
|
||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
|
||
from sqlalchemy.orm import Session
|
||
from langchain_chroma import Chroma
|
||
from sqlalchemy import text
|
||
import logging
|
||
from datetime import datetime, timedelta
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.orm import selectinload
|
||
import time
|
||
import asyncio
|
||
|
||
from app.db.session import get_db
|
||
from app.models.user import User
|
||
from app.core.security import get_current_user
|
||
from app.models.knowledge import KnowledgeBase, Document, ProcessingTask, DocumentChunk, DocumentUpload
|
||
from app.schemas.knowledge import (
|
||
KnowledgeBaseCreate,
|
||
KnowledgeBaseResponse,
|
||
KnowledgeBaseUpdate,
|
||
DocumentResponse,
|
||
PreviewRequest
|
||
)
|
||
from app.services.document_processor import process_document_background, upload_document, preview_document, PreviewResult
|
||
from app.core.config import settings
|
||
from app.core.minio import get_minio_client
|
||
from minio.error import MinioException
|
||
from app.services.vector_store import VectorStoreFactory
|
||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||
|
||
router = APIRouter()
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class TestRetrievalRequest(BaseModel):
    """
    Request body for the retrieval test endpoint.
    """

    # Query text passed to the vector store similarity search.
    query: str
    # ID of the knowledge base whose collection is searched.
    kb_id: int
    # Number of results requested from the similarity search.
    top_k: int
|
||
|
||
@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Create a knowledge base owned by the authenticated user.

    The name and description are taken from the request body and the
    owner is always the current user. The persisted row is returned.
    """
    owner_id = current_user.id
    new_kb = KnowledgeBase(
        name=kb_in.name,
        description=kb_in.description,
        user_id=owner_id,
    )

    db.add(new_kb)
    db.commit()
    # Reload the row state from the database after the commit.
    db.refresh(new_kb)

    logger.info(f"Knowledge base created: {new_kb.name} for user {current_user.id}")
    return new_kb
|
||
|
||
@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """
    List the current user's knowledge bases, one page at a time.

    ``skip`` rows are skipped and at most ``limit`` rows are returned.
    """
    owned = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == current_user.id)
    page = owned.offset(skip).limit(limit)
    return page.all()
|
||
|
||
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch one knowledge base owned by the current user.

    Documents and their processing tasks are eagerly loaded in the same
    query. Responds with 404 when no matching knowledge base exists.
    """
    from sqlalchemy.orm import joinedload

    # Eager-load documents -> processing_tasks along with the knowledge base.
    eager_documents = joinedload(KnowledgeBase.documents).joinedload(
        Document.processing_tasks
    )
    lookup = db.query(KnowledgeBase).options(eager_documents).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )

    kb = lookup.first()
    if kb:
        return kb

    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||
|
||
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Apply a partial update to a knowledge base owned by the current user.

    Only fields explicitly set in the request body are written. Responds
    with 404 when no matching knowledge base exists.
    """
    owned_match = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )
    kb = owned_match.first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # exclude_unset leaves omitted fields untouched on the stored row.
    changes = kb_in.dict(exclude_unset=True)
    for field_name in changes:
        setattr(kb, field_name, changes[field_name])

    db.add(kb)
    db.commit()
    db.refresh(kb)

    logger.info(f"Knowledge base updated: {kb.name} for user {current_user.id}")
    return kb
|
||
|
||
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete knowledge base and all associated resources.

    External resources (MinIO objects under ``kb_{kb_id}/`` and the
    ``kb_{kb_id}`` vector collection) are cleaned up first on a
    best-effort basis: failures there are logged and reported back as
    warnings rather than aborting. The database row is deleted last.

    Returns:
        A dict with a ``message`` and, when external cleanup partially
        failed, a ``warnings`` list of error descriptions.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for the
            current user; 500 if service initialisation or the database
            delete fails (the transaction is rolled back).
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    try:
        # Initialize services
        minio_client = get_minio_client()
        embeddings = EmbeddingsFactory.create()

        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Clean up external resources first
        cleanup_errors = []

        # 1. Clean up MinIO files
        try:
            # recursive=True is required: a non-recursive listing only
            # returns the first level below the prefix, so nested objects
            # such as kb_{kb_id}/temp/<file> would never be removed.
            objects = minio_client.list_objects(
                settings.MINIO_BUCKET_NAME,
                prefix=f"kb_{kb_id}/",
                recursive=True,
            )
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")

        # 2. Clean up vector store
        try:
            # NOTE(review): this reaches into the private ``_store``
            # attribute and passes a collection name; confirm that every
            # configured store type accepts this call signature.
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")

        # Finally, delete database records in a single transaction
        db.delete(kb)
        db.commit()

        # Report any cleanup errors in the response
        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors
            }

        return {"message": "Knowledge base and all associated resources deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete knowledge base: {str(e)}")
|
||
|
||
# Batch upload documents
|
||
@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Stage a batch of uploaded files in MinIO for later processing.

    For each file the SHA-256 of its content is computed. If a document
    with the same name and hash already exists in this knowledge base the
    file is reported as ``exists`` and skipped; otherwise it is stored
    under ``kb_{kb_id}/temp/`` and a ``DocumentUpload`` row is created.

    Returns a list with one result dict per input file. Responds with 404
    when the knowledge base is not found for the current user and 500
    when the MinIO upload fails.
    """
    owned_kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )
    if not owned_kb.first():
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    results = []
    for file in files:
        # 1. Compute the file hash.
        payload = await file.read()
        digest = hashlib.sha256(payload).hexdigest()

        # 2. Check for an identical file (same name and same hash).
        duplicate = db.query(Document).filter(
            Document.file_name == file.filename,
            Document.file_hash == digest,
            Document.knowledge_base_id == kb_id,
        ).first()

        if duplicate:
            # Identical file: report it and skip the upload.
            results.append(dict(
                document_id=duplicate.id,
                file_name=duplicate.file_name,
                status="exists",
                message="文件已存在且已处理完成",
                skip_processing=True,
            ))
            continue

        # 3. Upload to the temporary location.
        temp_path = f"kb_{kb_id}/temp/{file.filename}"
        # Rewind: the stream was consumed while hashing.
        await file.seek(0)
        # Reuse the length of the content that was already read.
        payload_size = len(payload)
        try:
            storage = get_minio_client()
            storage.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=temp_path,
                data=file.file,
                length=payload_size,
                content_type=file.content_type,
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")

        # 4. Create the upload record.
        upload = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=file.filename,
            file_hash=digest,
            file_size=payload_size,
            content_type=file.content_type,
            temp_path=temp_path,
        )
        db.add(upload)
        db.commit()
        db.refresh(upload)

        results.append(dict(
            upload_id=upload.id,
            file_name=file.filename,
            temp_path=temp_path,
            status="pending",
            skip_processing=False,
        ))

    return results
|
||
|
||
@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Dict[int, PreviewResult]:
    """
    Build chunk previews for several documents of one knowledge base.

    Each requested ID is looked up first as a ``Document`` and, failing
    that, as a ``DocumentUpload``; both lookups are restricted to this
    knowledge base and the current user. Responds with 404 as soon as an
    ID matches neither. Returns a mapping of requested ID to preview.
    """
    previews = {}
    for doc_id in preview_request.document_ids:
        stored_doc = db.query(Document).join(KnowledgeBase).filter(
            Document.id == doc_id,
            Document.knowledge_base_id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        ).first()

        if not stored_doc:
            # Not a processed document: fall back to a pending upload.
            pending = db.query(DocumentUpload).join(KnowledgeBase).filter(
                DocumentUpload.id == doc_id,
                DocumentUpload.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            ).first()

            if not pending:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")

            file_path = pending.temp_path
        else:
            file_path = stored_doc.file_path

        previews[doc_id] = await preview_document(
            file_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap,
        )

    return previews
|
||
|
||
@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Process multiple documents asynchronously.

    ``upload_results`` is the list of result dicts from the upload
    endpoint. Entries flagged ``skip_processing`` are ignored; for every
    other entry whose ``upload_id`` refers to an upload of this knowledge
    base, a pending ``ProcessingTask`` is created and the actual
    processing is scheduled as a background task.

    Returns:
        ``{"tasks": [{"upload_id": ..., "task_id": ...}, ...]}`` with one
        entry per task that was created.

    Raises:
        HTTPException: 404 if the knowledge base is not found for the
            current user.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()

    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    upload_ids = [
        result["upload_id"]
        for result in upload_results
        if not result.get("skip_processing")
    ]

    if not upload_ids:
        return {"tasks": []}

    # Restrict the lookup to this knowledge base so that upload ids that
    # belong to another knowledge base (or another user) are ignored.
    uploads = (
        db.query(DocumentUpload)
        .filter(
            DocumentUpload.id.in_(upload_ids),
            DocumentUpload.knowledge_base_id == kb_id,
        )
        .all()
    )
    uploads_dict = {upload.id: upload for upload in uploads}

    # Keep each task together with the upload it was created for, so the
    # pairing stays correct even when some requested ids are skipped.
    scheduled = []
    for upload_id in upload_ids:
        upload = uploads_dict.get(upload_id)
        if not upload:
            continue

        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending"
        )
        scheduled.append((upload_id, upload, task))

    db.add_all([task for _, _, task in scheduled])
    db.commit()

    task_info = []
    task_data = []
    for upload_id, upload, task in scheduled:
        # Refresh to pick up the task id assigned by the database.
        db.refresh(task)

        task_info.append({
            "upload_id": upload_id,
            "task_id": task.id
        })
        task_data.append({
            "task_id": task.id,
            "upload_id": upload_id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name
        })

    background_tasks.add_task(
        add_processing_tasks_to_queue,
        task_data,
        kb_id
    )

    return {"tasks": task_info}
|
||
|
||
# Strong references to in-flight processing tasks. The event loop only keeps
# weak references to tasks, so a task without another reference may be
# garbage-collected before it finishes.
_pending_processing_tasks = set()


async def add_processing_tasks_to_queue(task_data, kb_id):
    """
    Schedule document processing tasks without blocking the main response.

    Args:
        task_data: Iterable of dicts with the keys ``temp_path``,
            ``file_name`` and ``task_id``.
        kb_id: ID of the knowledge base the documents belong to.
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None
            )
        )
        # Hold a reference until the task completes, then drop it.
        _pending_processing_tasks.add(task)
        task.add_done_callback(_pending_processing_tasks.discard)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")
|
||
|
||
@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Remove upload records older than 24 hours and their temp files.

    A MinIO failure for one object is logged and does not stop the
    cleanup; the upload row is deleted either way.
    """
    cutoff = datetime.utcnow() - timedelta(hours=24)
    stale_uploads = (
        db.query(DocumentUpload)
        .filter(DocumentUpload.created_at < cutoff)
        .all()
    )

    storage = get_minio_client()
    for stale in stale_uploads:
        try:
            storage.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=stale.temp_path,
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {stale.temp_path}: {str(e)}")

        db.delete(stale)

    db.commit()

    return {"message": f"Cleaned up {len(stale_uploads)} expired uploads"}
|
||
|
||
@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Get status of multiple processing tasks.

    Args:
        kb_id: ID of the knowledge base the tasks belong to.
        task_ids: Comma-separated integer task IDs; blank segments (for
            example from a trailing comma) are ignored.

    Returns:
        A dict keyed by task ID with ``document_id``, ``status``,
        ``error_message``, ``upload_id`` and ``file_name`` for each task
        found in this knowledge base.

    Raises:
        HTTPException: 422 if ``task_ids`` contains a non-integer value;
            404 if the knowledge base is not found for the current user.
    """
    try:
        task_id_list = [
            int(raw_id)
            for raw_id in task_ids.split(",")
            if raw_id.strip()
        ]
    except ValueError:
        # Report malformed input as a client error instead of a 500.
        raise HTTPException(
            status_code=422,
            detail="task_ids must be a comma-separated list of integers",
        ) from None

    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()

    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    tasks = (
        db.query(ProcessingTask)
        .options(
            selectinload(ProcessingTask.document_upload)
        )
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id
        )
        .all()
    )

    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None
        }
        for task in tasks
    }
|
||
|
||
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch a single document of a knowledge base owned by the current user.

    Responds with 404 when no document matches the ID, the knowledge base
    and the owner together.
    """
    scoped = db.query(Document).join(KnowledgeBase).filter(
        Document.id == doc_id,
        Document.knowledge_base_id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )

    document = scoped.first()
    if document:
        return document

    raise HTTPException(status_code=404, detail="Document not found")
|
||
|
||
@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Test retrieval quality for a given query against a knowledge base.

    Runs a similarity search for ``request.query`` against the
    ``kb_{kb_id}`` collection and returns the ``request.top_k`` matches.

    Returns:
        ``{"results": [{"content", "metadata", "score"}, ...]}``.

    Raises:
        HTTPException: 404 if the knowledge base is not found for the
            current user; 500 for any other failure.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id
        ).first()

        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )

        embeddings = EmbeddingsFactory.create()

        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )

        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)

        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in results
        ]

        return {"results": response}

    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|