"""API routes for knowledge bases.

Covers CRUD on knowledge bases, batch document upload to MinIO, chunk
preview, asynchronous document processing, temp-file cleanup, processing-task
polling, and retrieval-quality testing against the vector store.
"""

import asyncio
import hashlib
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List

from fastapi import (
    APIRouter,
    BackgroundTasks,
    Depends,
    File,
    HTTPException,
    Query,
    UploadFile,
)
from langchain_chroma import Chroma
from minio.error import MinioException
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session, joinedload, selectinload

from app.core.config import settings
from app.core.minio import get_minio_client
from app.core.security import get_current_user
from app.db.session import get_db
from app.models.knowledge import (
    Document,
    DocumentChunk,
    DocumentUpload,
    KnowledgeBase,
    ProcessingTask,
)
from app.models.user import User
from app.schemas.knowledge import (
    DocumentResponse,
    KnowledgeBaseCreate,
    KnowledgeBaseResponse,
    KnowledgeBaseUpdate,
    PreviewRequest,
)
from app.services.document_processor import (
    PreviewResult,
    preview_document,
    process_document_background,
    upload_document,
)
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.vector_store import VectorStoreFactory

router = APIRouter()
logger = logging.getLogger(__name__)

# Strong references to fire-and-forget asyncio tasks. The event loop keeps
# only weak references to tasks, so without this set a background processing
# task could be garbage-collected before it finishes.
_background_task_refs: set = set()


class TestRetrievalRequest(BaseModel):
    """Request payload for the retrieval-quality test endpoint."""

    query: str
    kb_id: int
    top_k: int


@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user),
) -> Any:
    """Create a new knowledge base owned by the current user."""
    kb = KnowledgeBase(
        name=kb_in.name,
        description=kb_in.description,
        user_id=current_user.id,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info(f"Knowledge base created: {kb.name} for user {current_user.id}")
    return kb


@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100,
) -> Any:
    """List the current user's knowledge bases with simple offset pagination."""
    knowledge_bases = (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.user_id == current_user.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return knowledge_bases


@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user),
) -> Any:
    """Get one knowledge base (with its documents and their processing tasks).

    Raises:
        HTTPException: 404 if the KB does not exist or belongs to another user.
    """
    kb = (
        db.query(KnowledgeBase)
        # Eager-load documents and each document's processing tasks in the
        # same query so the response serialization does not lazy-load per row.
        .options(
            joinedload(KnowledgeBase.documents)
            .joinedload(Document.processing_tasks)
        )
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return kb


@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user),
) -> Any:
    """Partially update a knowledge base (only fields explicitly set in kb_in).

    Raises:
        HTTPException: 404 if the KB does not exist or belongs to another user.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # exclude_unset: fields the client omitted are left untouched.
    for field, value in kb_in.dict(exclude_unset=True).items():
        setattr(kb, field, value)

    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info(f"Knowledge base updated: {kb.name} for user {current_user.id}")
    return kb


@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user),
) -> Any:
    """Delete a knowledge base and all associated external resources.

    Cleanup order: MinIO objects first, then the vector-store collection,
    then the database rows in one transaction. External cleanup failures are
    collected and returned as warnings rather than aborting the delete.

    Raises:
        HTTPException: 404 if not found/not owned; 500 if the DB delete fails.
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    try:
        minio_client = get_minio_client()
        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Clean up external resources first; record (not raise) failures.
        cleanup_errors = []

        # 1. MinIO: every object for this KB lives under the kb_{id}/ prefix.
        try:
            objects = minio_client.list_objects(
                settings.MINIO_BUCKET_NAME, prefix=f"kb_{kb_id}/"
            )
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")

        # 2. Vector store: drop the whole collection for this KB.
        # NOTE(review): reaches into the private `_store` attribute of the
        # wrapper — consider exposing a public delete_collection() instead.
        try:
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")

        # 3. Database rows last, in a single transaction (cascades assumed
        # to remove documents/tasks — configured on the models).
        db.delete(kb)
        db.commit()

        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors,
            }
        return {"message": "Knowledge base and all associated resources deleted successfully"}

    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete knowledge base: {str(e)}",
        )


@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Batch-upload documents to MinIO temp storage and record the uploads.

    Files whose (name, sha256) already exist in this KB are skipped and
    reported with ``skip_processing=True``; new files are staged under
    ``kb_{kb_id}/temp/`` and get a DocumentUpload row with status ``pending``.

    Raises:
        HTTPException: 404 if the KB is not found; 500 on MinIO upload failure.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # One client for the whole batch instead of reconnecting per file.
    minio_client = get_minio_client()

    results = []
    for file in files:
        # 1. Hash the full content to detect duplicates.
        file_content = await file.read()
        file_hash = hashlib.sha256(file_content).hexdigest()

        # 2. Skip files that already exist with identical name AND hash.
        existing_document = db.query(Document).filter(
            Document.file_name == file.filename,
            Document.file_hash == file_hash,
            Document.knowledge_base_id == kb_id,
        ).first()
        if existing_document:
            results.append({
                "document_id": existing_document.id,
                "file_name": existing_document.file_name,
                "status": "exists",
                "message": "文件已存在且已处理完成",
                "skip_processing": True,
            })
            continue

        # 3. Stage the file in the KB's temp prefix. Rewind first: read()
        # above consumed the stream.
        temp_path = f"kb_{kb_id}/temp/{file.filename}"
        await file.seek(0)
        try:
            file_size = len(file_content)  # length of the content read above
            minio_client.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=temp_path,
                data=file.file,
                length=file_size,
                content_type=file.content_type,
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")

        # 4. Record the pending upload.
        upload = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=file.filename,
            file_hash=file_hash,
            file_size=len(file_content),
            content_type=file.content_type,
            temp_path=temp_path,
        )
        db.add(upload)
        db.commit()
        db.refresh(upload)

        results.append({
            "upload_id": upload.id,
            "file_name": file.filename,
            "temp_path": temp_path,
            "status": "pending",
            "skip_processing": False,
        })

    return results


@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Dict[int, PreviewResult]:
    """Preview chunking for multiple documents.

    Each requested id may refer to a processed Document or, failing that, a
    staged DocumentUpload (pre-processing preview).

    Raises:
        HTTPException: 404 for any id matching neither table for this user/KB.
    """
    results = {}
    for doc_id in preview_request.document_ids:
        document = db.query(Document).join(KnowledgeBase).filter(
            Document.id == doc_id,
            Document.knowledge_base_id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        ).first()

        if document:
            file_path = document.file_path
        else:
            # Fall back to a not-yet-processed upload staged in temp storage.
            upload = db.query(DocumentUpload).join(KnowledgeBase).filter(
                DocumentUpload.id == doc_id,
                DocumentUpload.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            ).first()
            if not upload:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
            file_path = upload.temp_path

        preview = await preview_document(
            file_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap,
        )
        results[doc_id] = preview

    return results


@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create processing tasks for uploaded documents and queue them.

    Entries flagged ``skip_processing`` are ignored. Returns the created
    (upload_id, task_id) pairs; actual processing happens after the response
    via a background task.

    Raises:
        HTTPException: 404 if the KB is not found.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    upload_ids = [
        result["upload_id"]
        for result in upload_results
        if not result.get("skip_processing")
    ]
    if not upload_ids:
        return {"tasks": []}

    uploads = db.query(DocumentUpload).filter(DocumentUpload.id.in_(upload_ids)).all()
    uploads_dict = {upload.id: upload for upload in uploads}

    # Pair each task with its upload explicitly. (The previous index-based
    # pairing misaligned tasks and uploads whenever an upload_id was missing
    # from the database.)
    pending: List[tuple] = []
    for upload_id in upload_ids:
        upload = uploads_dict.get(upload_id)
        if not upload:
            continue
        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending",
        )
        pending.append((upload, task))

    db.add_all([task for _, task in pending])
    db.commit()

    task_info = []
    task_data = []
    for upload, task in pending:
        db.refresh(task)  # populate the autogenerated task.id
        task_info.append({"upload_id": upload.id, "task_id": task.id})
        task_data.append({
            "task_id": task.id,
            "upload_id": upload.id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name,
        })

    # Defer queueing until after the response is sent.
    background_tasks.add_task(add_processing_tasks_to_queue, task_data, kb_id)

    return {"tasks": task_info}


async def add_processing_tasks_to_queue(task_data, kb_id):
    """Schedule document-processing coroutines without blocking the response.

    Keeps a strong reference to each created task (the event loop holds only
    weak refs, so unreferenced tasks can be garbage-collected mid-run).
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None,
            )
        )
        _background_task_refs.add(task)
        task.add_done_callback(_background_task_refs.discard)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")


@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete temp uploads older than 24 hours from MinIO and the database.

    MinIO removal failures are logged but the DB row is deleted regardless,
    so a stuck object cannot block cleanup forever.
    """
    # NOTE(review): utcnow() is naive — assumes created_at is stored as naive
    # UTC; confirm before migrating to timezone-aware datetimes.
    expired_time = datetime.utcnow() - timedelta(hours=24)
    expired_uploads = db.query(DocumentUpload).filter(
        DocumentUpload.created_at < expired_time
    ).all()

    minio_client = get_minio_client()
    for upload in expired_uploads:
        try:
            minio_client.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=upload.temp_path,
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {upload.temp_path}: {str(e)}")
        db.delete(upload)

    db.commit()
    return {"message": f"Cleaned up {len(expired_uploads)} expired uploads"}


@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get the status of multiple processing tasks, keyed by task id.

    Raises:
        HTTPException: 422 if task_ids is not a comma-separated integer list;
            404 if the KB is not found.
    """
    # Validate up front: a malformed id previously escaped as a bare
    # ValueError (HTTP 500) instead of a client error.
    try:
        task_id_list = [int(raw.strip()) for raw in task_ids.split(",")]
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail="task_ids must be a comma-separated list of integers",
        )

    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    tasks = (
        db.query(ProcessingTask)
        # Batch-load the related uploads instead of one query per task.
        .options(selectinload(ProcessingTask.document_upload))
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id,
        )
        .all()
    )

    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None,
        }
        for task in tasks
    }


@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user),
) -> Any:
    """Get one document's details, scoped to the current user's KB.

    Raises:
        HTTPException: 404 if the document is not found.
    """
    document = (
        db.query(Document)
        .join(KnowledgeBase)
        .filter(
            Document.id == doc_id,
            Document.knowledge_base_id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not document:
        raise HTTPException(status_code=404, detail="Document not found")
    return document


@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Run a similarity search against a KB's vector store and return scored hits.

    Raises:
        HTTPException: 404 if the KB is not found; 500 on retrieval errors.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id,
        ).first()
        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )

        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )

        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)

        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score),
            }
            for doc, score in results
        ]
        return {"results": response}
    except HTTPException:
        # Let the deliberate 404 above pass through instead of collapsing it
        # into a generic 500 (previous behavior).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))