Files
rag_agent/rag-web-ui/backend/app/api/api_v1/knowledge_base.py
2026-04-13 11:34:23 +08:00

576 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import hashlib
from typing import List, Any, Dict
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
from sqlalchemy.orm import Session
from langchain_chroma import Chroma
from sqlalchemy import text
import logging
from datetime import datetime, timedelta
from pydantic import BaseModel
from sqlalchemy.orm import selectinload
import time
import asyncio
from app.db.session import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.models.knowledge import KnowledgeBase, Document, ProcessingTask, DocumentChunk, DocumentUpload
from app.schemas.knowledge import (
KnowledgeBaseCreate,
KnowledgeBaseResponse,
KnowledgeBaseUpdate,
DocumentResponse,
PreviewRequest
)
from app.services.document_processor import process_document_background, upload_document, preview_document, PreviewResult
from app.core.config import settings
from app.core.minio import get_minio_client
from minio.error import MinioException
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
# APIRouter on which every endpoint in this module is registered.
router: APIRouter = APIRouter()
# Module-level logger shared by all handlers in this file.
logger: logging.Logger = logging.getLogger(__name__)
class TestRetrievalRequest(BaseModel):
    """Request body for the ``/test-retrieval`` endpoint.

    All three fields are required; no range validation is applied here.
    """

    # Free-text query passed to the vector store's similarity search.
    query: str
    # ID of the knowledge base; its collection is named ``kb_{kb_id}``.
    kb_id: int
    # Number of results requested from the similarity search (``k``).
    top_k: int
@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Create a knowledge base owned by the calling user.

    The new row is committed immediately and then refreshed, so any values
    populated by the database are loaded before the object is returned and
    serialized as ``KnowledgeBaseResponse``.
    """
    new_kb = KnowledgeBase(
        user_id=current_user.id,
        name=kb_in.name,
        description=kb_in.description,
    )
    db.add(new_kb)
    db.commit()
    db.refresh(new_kb)
    logger.info(f"Knowledge base created: {new_kb.name} for user {current_user.id}")
    return new_kb
@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """List the caller's knowledge bases, paginated by ``skip``/``limit``."""
    owned_by_caller = db.query(KnowledgeBase).filter(
        KnowledgeBase.user_id == current_user.id
    )
    requested_page = owned_by_caller.offset(skip).limit(limit)
    return requested_page.all()
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Return a single knowledge base owned by the caller.

    The knowledge base's documents, and each document's processing tasks,
    are eager-loaded with joined loading in the same query.

    Raises:
        HTTPException: 404 if no knowledge base with ``kb_id`` belongs to
            the caller.
    """
    from sqlalchemy.orm import joinedload

    eager_documents = joinedload(KnowledgeBase.documents).joinedload(
        Document.processing_tasks
    )
    match = (
        db.query(KnowledgeBase)
        .options(eager_documents)
        .filter(KnowledgeBase.id == kb_id, KnowledgeBase.user_id == current_user.id)
        .first()
    )
    if not match:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return match
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Apply a partial update to a knowledge base owned by the caller.

    Only the fields explicitly set on ``kb_in`` (``exclude_unset=True``) are
    copied onto the row before it is committed and refreshed.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user.
    """
    target = (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.id == kb_id, KnowledgeBase.user_id == current_user.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    changes = kb_in.dict(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    db.add(target)
    db.commit()
    db.refresh(target)
    logger.info(f"Knowledge base updated: {target.name} for user {current_user.id}")
    return target
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete knowledge base and all associated resources.

    External resources are removed first, best-effort:

    1. every MinIO object under the ``kb_{kb_id}/`` prefix, listed
       recursively so nested keys (e.g. ``kb_{kb_id}/temp/...``) are
       included;
    2. the vector-store collection ``kb_{kb_id}``.

    A ``MinioException`` in step 1, or any exception in step 2 (including a
    failure to build the embeddings / vector store), is logged and returned
    as a warning instead of aborting. The knowledge base row is then deleted
    and committed.

    Returns:
        ``{"message": ...}``, plus a ``"warnings"`` list when a cleanup step
        failed.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or is not
            owned by the caller; 500 (after rolling back the session) on any
            other error.
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    try:
        minio_client = get_minio_client()
        cleanup_errors = []

        # 1. Clean up MinIO files. recursive=True is required: a
        # non-recursive listing returns only the first level below the
        # prefix (sub-"directories" come back as prefix entries), so nested
        # objects would be left behind.
        try:
            objects = minio_client.list_objects(
                settings.MINIO_BUCKET_NAME,
                prefix=f"kb_{kb_id}/",
                recursive=True,
            )
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")

        # 2. Clean up the vector store. The store is built inside this try so
        # an embeddings / vector-store initialisation failure becomes a
        # warning instead of preventing the knowledge base from being deleted.
        try:
            embeddings = EmbeddingsFactory.create()
            vector_store = VectorStoreFactory.create(
                store_type=settings.VECTOR_STORE_TYPE,
                collection_name=f"kb_{kb_id}",
                embedding_function=embeddings,
            )
            # NOTE(review): relies on the wrapper's private ``_store``
            # attribute exposing ``delete_collection``; confirm against
            # VectorStoreFactory.
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")

        # Finally, delete the database record. Whether related rows are
        # removed too depends on the model's cascade configuration.
        db.delete(kb)
        db.commit()

        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors
            }
        return {"message": "Knowledge base and all associated resources deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete knowledge base: {str(e)}",
        ) from e
# Batch upload documents
@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload multiple documents to MinIO.

    Each file is read fully and hashed with SHA-256.

    - **Duplicate:** if this knowledge base already has a ``Document`` with
      the same file name and hash, the file is reported with status
      ``"exists"`` and skipped.
    - **New file:** its bytes are staged in MinIO under
      ``kb_{kb_id}/temp/<filename>``, and a ``DocumentUpload`` row is
      committed and reported with status ``"pending"``.

    Returns:
        One result dict per input file, in input order.

    Raises:
        HTTPException: 404 if the knowledge base is not owned by the caller;
            500 if staging a file in MinIO fails. Files earlier in the
            batch remain committed, since each upload row is committed
            inside the loop.
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.id == kb_id, KnowledgeBase.user_id == current_user.id)
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    results = []
    for upload_file in files:
        # 1. Hash the full file contents.
        raw_bytes = await upload_file.read()
        digest = hashlib.sha256(raw_bytes).hexdigest()

        # 2. Skip files whose name AND content hash both match an existing
        #    document in this knowledge base.
        duplicate = (
            db.query(Document)
            .filter(
                Document.file_name == upload_file.filename,
                Document.file_hash == digest,
                Document.knowledge_base_id == kb_id,
            )
            .first()
        )
        if duplicate:
            results.append({
                "document_id": duplicate.id,
                "file_name": duplicate.file_name,
                "status": "exists",
                "message": "文件已存在且已处理完成",
                "skip_processing": True
            })
            continue

        # 3. Stage the bytes under the knowledge base's temp prefix; rewind
        #    the upload so MinIO streams it from the beginning.
        # NOTE(review): the staging key depends only on the file name, so
        # two pending uploads with the same name but different content
        # overwrite each other's staged object.
        staging_path = f"kb_{kb_id}/temp/{upload_file.filename}"
        await upload_file.seek(0)
        try:
            client = get_minio_client()
            byte_count = len(raw_bytes)  # size of the content already read
            client.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=staging_path,
                data=upload_file.file,
                length=byte_count,
                content_type=upload_file.content_type
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")

        # 4. Record the pending upload.
        record = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=upload_file.filename,
            file_hash=digest,
            file_size=len(raw_bytes),
            content_type=upload_file.content_type,
            temp_path=staging_path
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        results.append({
            "upload_id": record.id,
            "file_name": upload_file.filename,
            "temp_path": staging_path,
            "status": "pending",
            "skip_processing": False
        })
    return results
@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Dict[int, PreviewResult]:
    """
    Preview multiple documents' chunks.

    Each ID in ``preview_request.document_ids`` is resolved to a file path:

    - first as a ``Document`` of this knowledge base (using its
      ``file_path``);
    - otherwise as a ``DocumentUpload`` (using its ``temp_path``).

    The file is then chunked by ``preview_document`` with the requested
    ``chunk_size`` / ``chunk_overlap``.

    NOTE(review): documents and uploads share one ID namespace here — an ID
    present in both tables always resolves to the Document.

    Raises:
        HTTPException: 404 for the first ID that matches neither a document
            nor an upload owned by the caller in this knowledge base.
    """
    previews = {}
    for doc_id in preview_request.document_ids:
        document = (
            db.query(Document)
            .join(KnowledgeBase)
            .filter(
                Document.id == doc_id,
                Document.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )
        if document:
            source_path = document.file_path
        else:
            pending_upload = (
                db.query(DocumentUpload)
                .join(KnowledgeBase)
                .filter(
                    DocumentUpload.id == doc_id,
                    DocumentUpload.knowledge_base_id == kb_id,
                    KnowledgeBase.user_id == current_user.id,
                )
                .first()
            )
            if not pending_upload:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
            source_path = pending_upload.temp_path

        previews[doc_id] = await preview_document(
            source_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap
        )
    return previews
@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Process multiple documents asynchronously.

    ``upload_results`` is the list produced by the upload endpoint. Entries
    flagged ``skip_processing`` or lacking an ``upload_id`` are ignored. A
    ``ProcessingTask`` with status ``"pending"`` is created for every
    remaining upload that belongs to this knowledge base. The actual
    processing is handed to ``add_processing_tasks_to_queue`` as a FastAPI
    background task.

    Returns:
        ``{"tasks": [{"upload_id": ..., "task_id": ...}, ...]}``, with one
        entry per created task. Requested upload IDs that do not exist in
        this knowledge base are omitted.

    Raises:
        HTTPException: 404 if the knowledge base is not owned by the caller.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    upload_ids = [
        result["upload_id"]
        for result in upload_results
        if not result.get("skip_processing") and result.get("upload_id") is not None
    ]
    if not upload_ids:
        return {"tasks": []}

    # Restrict to uploads of THIS knowledge base: the IDs come from the
    # request body, and without this filter a caller could queue another
    # user's staged file into their own knowledge base.
    uploads = (
        db.query(DocumentUpload)
        .filter(
            DocumentUpload.id.in_(upload_ids),
            DocumentUpload.knowledge_base_id == kb_id,
        )
        .all()
    )
    uploads_by_id = {upload.id: upload for upload in uploads}

    # Keep every task paired with the upload it was created for, so the
    # reported/queued IDs cannot drift when some requested uploads are
    # missing.
    pending = []  # (requested upload_id, DocumentUpload, ProcessingTask)
    for upload_id in upload_ids:
        upload = uploads_by_id.get(upload_id)
        if upload is None:
            continue
        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending"
        )
        pending.append((upload_id, upload, task))

    db.add_all([task for _, _, task in pending])
    db.commit()
    for _, _, task in pending:
        db.refresh(task)  # load the generated task IDs

    task_info = [
        {"upload_id": upload_id, "task_id": task.id}
        for upload_id, _, task in pending
    ]
    task_data = [
        {
            "task_id": task.id,
            "upload_id": upload_id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name
        }
        for upload_id, upload, task in pending
    ]

    background_tasks.add_task(
        add_processing_tasks_to_queue,
        task_data,
        kb_id
    )
    return {"tasks": task_info}
# Strong references to in-flight document-processing tasks. The event loop
# only keeps weak references to tasks, so a task created with
# asyncio.create_task and not referenced elsewhere can be garbage-collected
# before it finishes.
_processing_tasks: set = set()


def _on_processing_task_done(task: "asyncio.Task") -> None:
    """Release a finished processing task and log any exception it raised."""
    _processing_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()  # also marks the exception as retrieved
    if exc is not None:
        logger.error("Document processing task failed", exc_info=exc)


async def add_processing_tasks_to_queue(task_data, kb_id):
    """Helper function to add document processing tasks to the queue without blocking the main response.

    Schedules one ``process_document_background`` coroutine per entry and
    returns without awaiting them. Each task is kept referenced until it
    completes, and failures are logged.

    Args:
        task_data: Iterable of dicts with ``temp_path``, ``file_name`` and
            ``task_id`` keys.
        kb_id: ID of the knowledge base the documents belong to.
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None
            )
        )
        _processing_tasks.add(task)
        task.add_done_callback(_on_processing_task_done)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")
@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Clean up expired temporary files.

    Every ``DocumentUpload`` whose ``created_at`` is more than 24 hours
    before ``datetime.utcnow()`` (naive UTC) has its staged MinIO object
    removed and its row deleted. A failed MinIO removal is only logged; the
    row is deleted regardless. All deletions are committed once at the end.

    NOTE(review): the query is not scoped to ``current_user``, so any
    authenticated caller purges expired uploads across all knowledge
    bases. Confirm this is intended.
    """
    cutoff = datetime.utcnow() - timedelta(hours=24)
    stale_uploads = (
        db.query(DocumentUpload)
        .filter(DocumentUpload.created_at < cutoff)
        .all()
    )

    client = get_minio_client()
    for stale in stale_uploads:
        try:
            client.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=stale.temp_path
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {stale.temp_path}: {str(e)}")
        db.delete(stale)

    db.commit()
    return {"message": f"Cleaned up {len(stale_uploads)} expired uploads"}
@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Get status of multiple processing tasks.

    Args:
        task_ids: Comma-separated integer task IDs. Surrounding whitespace is
            allowed, and blank entries (e.g. from a trailing comma) are
            ignored.

    Returns:
        Mapping of task ID to its ``document_id``, ``status``,
        ``error_message``, ``upload_id`` and the related upload's
        ``file_name``. The related upload is eager-loaded with
        ``selectinload``. IDs not belonging to ``kb_id`` are omitted.

    Raises:
        HTTPException: 422 if ``task_ids`` contains a non-integer entry;
            404 if the knowledge base is not owned by the caller.
    """
    try:
        task_id_list = [int(part) for part in task_ids.split(",") if part.strip()]
    except ValueError:
        # A malformed query parameter is a client error, not a server error.
        raise HTTPException(
            status_code=422,
            detail="task_ids must be a comma-separated list of integers",
        ) from None

    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    tasks = (
        db.query(ProcessingTask)
        .options(
            selectinload(ProcessingTask.document_upload)
        )
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id
        )
        .all()
    )
    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None
        }
        for task in tasks
    }
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """Return one document from a knowledge base owned by the caller.

    Raises:
        HTTPException: 404 if the document is not in ``kb_id`` or that
            knowledge base belongs to another user.
    """
    ownership_checks = (
        Document.id == doc_id,
        Document.knowledge_base_id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )
    found = db.query(Document).join(KnowledgeBase).filter(*ownership_checks).first()
    if not found:
        raise HTTPException(status_code=404, detail="Document not found")
    return found
@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Test retrieval quality for a given query against a knowledge base.

    Runs ``similarity_search_with_score`` for ``request.query`` with
    ``k=request.top_k`` over the vector collection ``kb_{request.kb_id}``.
    ``background_tasks`` is accepted but unused.

    Returns:
        ``{"results": [{"content": ..., "metadata": ..., "score": float}, ...]}``.

    Raises:
        HTTPException: 404 if the knowledge base is not owned by the caller;
            500 with the error text for any other failure.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id
        ).first()
        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )

        embeddings = EmbeddingsFactory.create()
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )
        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)

        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in results
        ]
        return {"results": response}
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must pass through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Test retrieval failed for kb {request.kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e