init. project
This commit is contained in:
0
rag-web-ui/backend/app/api/api_v1/__init__.py
Normal file
0
rag-web-ui/backend/app/api/api_v1/__init__.py
Normal file
11
rag-web-ui/backend/app/api/api_v1/api.py
Normal file
11
rag-web-ui/backend/app/api/api_v1/api.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.api_v1 import api_keys, auth, chat, knowledge_base, testing, tools
|
||||
|
||||
# Top-level router for API v1. Each feature module's router is mounted under
# a URL prefix and tagged with the same name.
api_router = APIRouter()

# (module, name) pairs in registration order; the name is used both as the
# URL prefix ("/<name>") and as the single tag.
_ROUTE_MODULES = (
    (auth, "auth"),
    (knowledge_base, "knowledge-base"),
    (chat, "chat"),
    (api_keys, "api-keys"),
    (testing, "testing"),
    (tools, "tools"),
)

for _module, _name in _ROUTE_MODULES:
    api_router.include_router(_module.router, prefix=f"/{_name}", tags=[_name])
|
||||
84
rag-web-ui/backend/app/api/api_v1/api_keys.py
Normal file
84
rag-web-ui/backend/app/api/api_v1/api_keys.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
|
||||
from app import models, schemas
|
||||
from app.db.session import get_db
|
||||
from app.services.api_key import APIKeyService
|
||||
from app.core.security import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.get("/", response_model=List[schemas.APIKey])
def read_api_keys(
    db: Session = Depends(get_db),
    skip: int = 0,
    limit: int = 100,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    List the API keys that belong to the authenticated user.

    `skip` and `limit` are passed straight through to
    `APIKeyService.get_api_keys` together with the caller's user id.
    """
    return APIKeyService.get_api_keys(
        db=db,
        user_id=current_user.id,
        skip=skip,
        limit=limit,
    )
|
||||
|
||||
@router.post("/", response_model=schemas.APIKey)
def create_api_key(
    *,
    db: Session = Depends(get_db),
    api_key_in: schemas.APIKeyCreate,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Create new API key.

    The key is created for the authenticated user through
    `APIKeyService.create_api_key`, using the name from the request body,
    and the created record is returned.
    """
    api_key = APIKeyService.create_api_key(
        db=db, user_id=current_user.id, name=api_key_in.name
    )
    # Log the record id, never the key value: the key is a credential and
    # must not be written to log files.
    logger.info("API key created: id=%s for user %s", api_key.id, current_user.id)
    return api_key
|
||||
|
||||
@router.put("/{id}", response_model=schemas.APIKey)
def update_api_key(
    *,
    db: Session = Depends(get_db),
    id: int,
    api_key_in: schemas.APIKeyUpdate,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Update API key.

    Raises:
        HTTPException: 404 if no API key has the given id; 403 if the key
            belongs to a different user.
    """
    api_key = APIKeyService.get_api_key(db=db, api_key_id=id)
    if not api_key:
        raise HTTPException(status_code=404, detail="API key not found")
    if api_key.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    api_key = APIKeyService.update_api_key(db=db, api_key=api_key, update_data=api_key_in)
    # Log the record id, never the key value: the key is a credential and
    # must not be written to log files.
    logger.info("API key updated: id=%s for user %s", id, current_user.id)
    return api_key
|
||||
|
||||
@router.delete("/{id}", response_model=schemas.APIKey)
def delete_api_key(
    *,
    db: Session = Depends(get_db),
    id: int,
    current_user: models.User = Depends(get_current_user),
) -> Any:
    """
    Delete API key.

    Returns the key object that was passed to
    `APIKeyService.delete_api_key`.

    Raises:
        HTTPException: 404 if no API key has the given id; 403 if the key
            belongs to a different user.
    """
    api_key = APIKeyService.get_api_key(db=db, api_key_id=id)
    if not api_key:
        raise HTTPException(status_code=404, detail="API key not found")
    if api_key.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    APIKeyService.delete_api_key(db=db, api_key=api_key)
    # Log the record id, never the key value: the key is a credential and
    # must not be written to log files.
    logger.info("API key deleted: id=%s for user %s", id, current_user.id)
    return api_key
|
||||
88
rag-web-ui/backend/app/api/api_v1/auth.py
Normal file
88
rag-web-ui/backend/app/api/api_v1/auth.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from app.core import security
|
||||
from app.core.security import get_current_user
|
||||
from app.core.config import settings
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.token import Token
|
||||
from app.schemas.user import UserCreate, UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/register", response_model=UserResponse)
def register(*, db: Session = Depends(get_db), user_in: UserCreate) -> Any:
    """
    Create a user account from the submitted e-mail, username and password.

    The e-mail is checked for an existing account first, then the username.
    The password is stored only as the hash produced by
    `security.get_password_hash`.

    Raises:
        HTTPException: 400 if the e-mail or username is already taken;
            503 if a `RequestException` is raised while registering.
    """
    try:
        # E-mail uniqueness is checked before username uniqueness.
        email_owner = db.query(User).filter(User.email == user_in.email).first()
        if email_owner:
            raise HTTPException(
                status_code=400,
                detail="A user with this email already exists.",
            )

        username_owner = (
            db.query(User).filter(User.username == user_in.username).first()
        )
        if username_owner:
            raise HTTPException(
                status_code=400,
                detail="A user with this username already exists.",
            )

        # Persist the new account and reload it before returning.
        password_hash = security.get_password_hash(user_in.password)
        new_user = User(
            email=user_in.email,
            username=user_in.username,
            hashed_password=password_hash,
        )
        db.add(new_user)
        db.commit()
        db.refresh(new_user)
        return new_user
    except RequestException as exc:
        raise HTTPException(
            status_code=503,
            detail="Network error or server is unreachable. Please try again later.",
        ) from exc
|
||||
|
||||
@router.post("/token", response_model=Token)
def login_access_token(
    db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
) -> Any:
    """
    Exchange a username and password for a bearer access token.

    The token's `sub` claim is the username, and its lifetime is
    `settings.ACCESS_TOKEN_EXPIRE_MINUTES` minutes.

    Raises:
        HTTPException: 401 if the user is unknown, the password does not
            verify, or the user is inactive.
    """
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    account = db.query(User).filter(User.username == form_data.username).first()
    # Unknown user and wrong password share one error message.
    if not (account and security.verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers=bearer_challenge,
        )
    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive user",
            headers=bearer_challenge,
        )

    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    token = security.create_access_token(
        data={"sub": account.username},
        expires_delta=lifetime,
    )
    return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
@router.post("/test-token", response_model=UserResponse)
def test_token(current_user: User = Depends(get_current_user)) -> Any:
    """
    Return the user resolved by the `get_current_user` dependency.

    The endpoint does no work of its own, so a successful response shows
    that the dependency accepted the request.
    """
    return current_user
|
||||
155
rag-web-ui/backend/app/api/api_v1/chat.py
Normal file
155
rag-web-ui/backend/app/api/api_v1/chat.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from typing import List, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.models.chat import Chat, Message
|
||||
from app.models.knowledge import KnowledgeBase
|
||||
from app.schemas.chat import (
|
||||
ChatCreate,
|
||||
ChatResponse,
|
||||
ChatUpdate,
|
||||
MessageCreate,
|
||||
MessageResponse
|
||||
)
|
||||
from app.core.security import get_current_user
|
||||
from app.services.chat_service import generate_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
def create_chat(
    *,
    db: Session = Depends(get_db),
    chat_in: ChatCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Create a chat linked to the requested knowledge bases.

    Every requested knowledge base must exist and belong to the current
    user. Duplicate ids in the request are treated as one.

    Raises:
        HTTPException: 400 if any requested knowledge base is not found
            among the current user's knowledge bases.
    """
    # De-duplicate while preserving order. The query returns each matching
    # row once, so comparing against the raw list length would wrongly
    # reject a request that repeats an id.
    requested_ids = list(dict.fromkeys(chat_in.knowledge_base_ids))

    # Verify knowledge bases exist and belong to user
    knowledge_bases = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id.in_(requested_ids),
            KnowledgeBase.user_id == current_user.id
        )
        .all()
    )
    if len(knowledge_bases) != len(requested_ids):
        raise HTTPException(
            status_code=400,
            detail="One or more knowledge bases not found"
        )

    chat = Chat(
        title=chat_in.title,
        user_id=current_user.id,
    )
    chat.knowledge_bases = knowledge_bases

    db.add(chat)
    db.commit()
    db.refresh(chat)
    return chat
|
||||
|
||||
@router.get("/", response_model=List[ChatResponse])
def get_chats(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """
    List the current user's chats, paginated with `skip` and `limit`.
    """
    owned_chats = db.query(Chat).filter(Chat.user_id == current_user.id)
    return owned_chats.offset(skip).limit(limit).all()
|
||||
|
||||
@router.get("/{chat_id}", response_model=ChatResponse)
def get_chat(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch one chat by id, restricted to chats owned by the current user.

    Raises:
        HTTPException: 404 if no such chat exists for the current user.
    """
    owned_match = db.query(Chat).filter(
        Chat.id == chat_id,
        Chat.user_id == current_user.id,
    )
    found = owned_match.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
@router.post("/{chat_id}/messages")
async def create_message(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    messages: dict,
    current_user: User = Depends(get_current_user)
) -> StreamingResponse:
    """
    Stream a generated reply to the last user message of a chat.

    The request body must be an object with a non-empty `messages` list
    whose last entry has `role == "user"` and a `content` field. The reply
    is produced by `generate_response` and streamed back as
    `text/event-stream`.

    Raises:
        HTTPException: 404 if the chat does not exist for the current user;
            400 if the body is malformed or the last message is not from
            the user.
    """
    chat = (
        db.query(Chat)
        .options(joinedload(Chat.knowledge_bases))
        .filter(
            Chat.id == chat_id,
            Chat.user_id == current_user.id
        )
        .first()
    )
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found")

    # Validate the untyped body up front so a malformed request gets a 400
    # instead of an unhandled KeyError/IndexError.
    history = messages.get("messages")
    if not isinstance(history, list) or not history:
        raise HTTPException(
            status_code=400,
            detail="Request body must contain a non-empty 'messages' list"
        )

    # Get the last user message
    last_message = history[-1]
    if not isinstance(last_message, dict) or last_message.get("role") != "user":
        raise HTTPException(status_code=400, detail="Last message must be from user")
    if "content" not in last_message:
        # Checked here because a failure inside the stream would happen
        # after the response has already started.
        raise HTTPException(status_code=400, detail="Last message must have content")

    # Get knowledge base IDs
    knowledge_base_ids = [kb.id for kb in chat.knowledge_bases]

    async def response_stream():
        async for chunk in generate_response(
            query=last_message["content"],
            messages=messages,
            knowledge_base_ids=knowledge_base_ids,
            chat_id=chat_id,
            db=db
        ):
            yield chunk

    return StreamingResponse(
        response_stream(),
        media_type="text/event-stream",
        headers={
            "x-vercel-ai-data-stream": "v1"
        }
    )
|
||||
|
||||
@router.delete("/{chat_id}")
def delete_chat(
    *,
    db: Session = Depends(get_db),
    chat_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete one of the current user's chats.

    Raises:
        HTTPException: 404 if no such chat exists for the current user.
    """
    owned_match = db.query(Chat).filter(
        Chat.id == chat_id,
        Chat.user_id == current_user.id,
    )
    doomed = owned_match.first()
    if not doomed:
        raise HTTPException(status_code=404, detail="Chat not found")

    db.delete(doomed)
    db.commit()
    return {"status": "success"}
|
||||
575
rag-web-ui/backend/app/api/api_v1/knowledge_base.py
Normal file
575
rag-web-ui/backend/app/api/api_v1/knowledge_base.py
Normal file
@@ -0,0 +1,575 @@
|
||||
import hashlib
|
||||
from typing import List, Any, Dict
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from langchain_chroma import Chroma
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import selectinload
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import get_current_user
|
||||
from app.models.knowledge import KnowledgeBase, Document, ProcessingTask, DocumentChunk, DocumentUpload
|
||||
from app.schemas.knowledge import (
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeBaseResponse,
|
||||
KnowledgeBaseUpdate,
|
||||
DocumentResponse,
|
||||
PreviewRequest
|
||||
)
|
||||
from app.services.document_processor import process_document_background, upload_document, preview_document, PreviewResult
|
||||
from app.core.config import settings
|
||||
from app.core.minio import get_minio_client
|
||||
from minio.error import MinioException
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestRetrievalRequest(BaseModel):
    """Request body for the `/test-retrieval` endpoint."""

    # Text to search for in the knowledge base's vector store.
    query: str
    # Id of the knowledge base to search; must belong to the caller.
    kb_id: int
    # Number of results requested from the similarity search.
    top_k: int
|
||||
|
||||
@router.post("", response_model=KnowledgeBaseResponse)
def create_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_in: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Create a knowledge base owned by the current user.

    The name and description come from the request body. The stored record
    is refreshed from the database and returned.
    """
    new_kb = KnowledgeBase(
        user_id=current_user.id,
        name=kb_in.name,
        description=kb_in.description,
    )
    db.add(new_kb)
    db.commit()
    db.refresh(new_kb)

    logger.info(f"Knowledge base created: {new_kb.name} for user {current_user.id}")
    return new_kb
|
||||
|
||||
@router.get("", response_model=List[KnowledgeBaseResponse])
def get_knowledge_bases(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 100
) -> Any:
    """
    List the current user's knowledge bases, paginated with `skip` and
    `limit`.
    """
    owned = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == current_user.id)
    return owned.offset(skip).limit(limit).all()
|
||||
|
||||
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch one of the current user's knowledge bases by id.

    Documents and each document's processing tasks are eager-loaded in the
    same query.

    Raises:
        HTTPException: 404 if no such knowledge base exists for the
            current user.
    """
    from sqlalchemy.orm import joinedload

    # Load knowledge base -> documents -> processing tasks together.
    eager_tree = joinedload(KnowledgeBase.documents).joinedload(
        Document.processing_tasks
    )
    lookup = (
        db.query(KnowledgeBase)
        .options(eager_tree)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
    )
    found = lookup.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    kb_in: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Apply a partial update to one of the current user's knowledge bases.

    Only fields explicitly set in the request body are written.

    Raises:
        HTTPException: 404 if no such knowledge base exists for the
            current user.
    """
    target = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # exclude_unset leaves out fields the client did not send.
    changes = kb_in.dict(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(target, attr_name, new_value)

    db.add(target)
    db.commit()
    db.refresh(target)

    logger.info(f"Knowledge base updated: {target.name} for user {current_user.id}")
    return target
|
||||
|
||||
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Delete knowledge base and all associated resources.

    MinIO objects under the `kb_{kb_id}/` prefix and the `kb_{kb_id}`
    vector collection are removed first on a best-effort basis. Failures
    there are logged and returned as warnings. The database record is
    deleted last.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for the
            current user; 500 on any other failure, after the session has
            been rolled back.
    """
    kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id
        )
        .first()
    )
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    try:
        # Initialize services
        minio_client = get_minio_client()
        embeddings = EmbeddingsFactory.create()

        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Clean up external resources first
        cleanup_errors = []

        # 1. Clean up MinIO files
        try:
            # Delete all objects with prefix kb_{kb_id}/
            objects = minio_client.list_objects(settings.MINIO_BUCKET_NAME, prefix=f"kb_{kb_id}/")
            for obj in objects:
                minio_client.remove_object(settings.MINIO_BUCKET_NAME, obj.object_name)
            logger.info(f"Cleaned up MinIO files for knowledge base {kb_id}")
        except MinioException as e:
            cleanup_errors.append(f"Failed to clean up MinIO files: {str(e)}")
            logger.error(f"MinIO cleanup error for kb {kb_id}: {str(e)}")

        # 2. Clean up vector store
        try:
            vector_store._store.delete_collection(f"kb_{kb_id}")
            logger.info(f"Cleaned up vector store for knowledge base {kb_id}")
        except Exception as e:
            # Best effort: record the failure and still delete the record.
            cleanup_errors.append(f"Failed to clean up vector store: {str(e)}")
            logger.error(f"Vector store cleanup error for kb {kb_id}: {str(e)}")

        # Finally, delete database records in a single transaction
        db.delete(kb)
        db.commit()

        # Report any cleanup errors in the response
        if cleanup_errors:
            return {
                "message": "Knowledge base deleted with cleanup warnings",
                "warnings": cleanup_errors
            }

        return {"message": "Knowledge base and all associated resources deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete knowledge base: {str(e)}")
|
||||
|
||||
# Batch upload documents
|
||||
@router.post("/{kb_id}/documents/upload")
async def upload_kb_documents(
    kb_id: int,
    files: List[UploadFile],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Stage uploaded files in MinIO under the knowledge base's temp prefix.

    Each file is hashed with SHA-256. If a document with the same name and
    hash already exists in the knowledge base, it is reported as "exists"
    and not uploaded again. Otherwise the file is stored at
    `kb_{kb_id}/temp/<filename>` and a `DocumentUpload` row is created.

    Returns:
        A list with one status dict per submitted file, in request order.

    Raises:
        HTTPException: 404 if the knowledge base does not exist for the
            current user; 500 if MinIO rejects an upload.
    """
    owned_kb = (
        db.query(KnowledgeBase)
        .filter(
            KnowledgeBase.id == kb_id,
            KnowledgeBase.user_id == current_user.id,
        )
        .first()
    )
    if not owned_kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    outcomes = []
    for incoming in files:
        # 1. Compute the file hash.
        payload = await incoming.read()
        digest = hashlib.sha256(payload).hexdigest()

        # 2. Look for an identical file (same name and same hash).
        duplicate = (
            db.query(Document)
            .filter(
                Document.file_name == incoming.filename,
                Document.file_hash == digest,
                Document.knowledge_base_id == kb_id,
            )
            .first()
        )
        if duplicate:
            # Identical file: report it and skip the upload.
            outcomes.append(
                {
                    "document_id": duplicate.id,
                    "file_name": duplicate.file_name,
                    "status": "exists",
                    "message": "文件已存在且已处理完成",
                    "skip_processing": True,
                }
            )
            continue

        # 3. Upload to the temporary location.
        staging_path = f"kb_{kb_id}/temp/{incoming.filename}"
        await incoming.seek(0)  # rewind: the stream was consumed for hashing
        try:
            storage = get_minio_client()
            byte_count = len(payload)  # length of the content read above
            storage.put_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=staging_path,
                data=incoming.file,
                length=byte_count,  # explicit object size
                content_type=incoming.content_type,
            )
        except MinioException as e:
            logger.error(f"Failed to upload file to MinIO: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to upload file")

        # 4. Create the upload record.
        record = DocumentUpload(
            knowledge_base_id=kb_id,
            file_name=incoming.filename,
            file_hash=digest,
            file_size=len(payload),
            content_type=incoming.content_type,
            temp_path=staging_path,
        )
        db.add(record)
        db.commit()
        db.refresh(record)

        outcomes.append(
            {
                "upload_id": record.id,
                "file_name": incoming.filename,
                "temp_path": staging_path,
                "status": "pending",
                "skip_processing": False,
            }
        )

    return outcomes
|
||||
|
||||
@router.post("/{kb_id}/documents/preview")
async def preview_kb_documents(
    kb_id: int,
    preview_request: PreviewRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Dict[int, PreviewResult]:
    """
    Build chunk previews for several documents of one knowledge base.

    Each requested id is looked up first as a `Document` and, if that
    fails, as a `DocumentUpload`. Both lookups are restricted to this
    knowledge base and the current user.

    Returns:
        A mapping from requested id to the result of `preview_document`.

    Raises:
        HTTPException: 404 if an id matches neither a document nor an
            upload.
    """
    previews = {}
    for doc_id in preview_request.document_ids:
        processed = (
            db.query(Document)
            .join(KnowledgeBase)
            .filter(
                Document.id == doc_id,
                Document.knowledge_base_id == kb_id,
                KnowledgeBase.user_id == current_user.id,
            )
            .first()
        )

        if processed:
            source_path = processed.file_path
        else:
            # Fall back to a staged upload with the same id.
            staged = (
                db.query(DocumentUpload)
                .join(KnowledgeBase)
                .filter(
                    DocumentUpload.id == doc_id,
                    DocumentUpload.knowledge_base_id == kb_id,
                    KnowledgeBase.user_id == current_user.id,
                )
                .first()
            )
            if not staged:
                raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
            source_path = staged.temp_path

        previews[doc_id] = await preview_document(
            source_path,
            chunk_size=preview_request.chunk_size,
            chunk_overlap=preview_request.chunk_overlap,
        )

    return previews
|
||||
|
||||
@router.post("/{kb_id}/documents/process")
async def process_kb_documents(
    kb_id: int,
    upload_results: List[dict],
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Process multiple documents asynchronously.

    A pending `ProcessingTask` is created for each referenced upload that
    belongs to this knowledge base, and the processing itself is handed to
    a background task. Entries flagged `skip_processing` and upload ids
    that do not resolve to an upload of this knowledge base are ignored.

    Returns:
        `{"tasks": [{"upload_id": ..., "task_id": ...}, ...]}`

    Raises:
        HTTPException: 404 if the knowledge base does not exist for the
            current user; 400 if an entry to be processed has no
            `upload_id`.
    """
    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()

    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    upload_ids = []
    for result in upload_results:
        if result.get("skip_processing"):
            continue
        upload_id = result.get("upload_id")
        if upload_id is None:
            # The body is untyped, so reject bad entries explicitly instead
            # of failing with a KeyError.
            raise HTTPException(
                status_code=400,
                detail="Each upload result to process must include 'upload_id'"
            )
        upload_ids.append(upload_id)

    if not upload_ids:
        return {"tasks": []}

    # Restrict the lookup to this knowledge base so a caller cannot process
    # uploads that belong to a different knowledge base.
    uploads = db.query(DocumentUpload).filter(
        DocumentUpload.id.in_(upload_ids),
        DocumentUpload.knowledge_base_id == kb_id
    ).all()
    uploads_dict = {upload.id: upload for upload in uploads}

    # Keep each task next to the upload it was created for. Pairing them by
    # list index afterwards would misalign as soon as one id is skipped.
    scheduled = []
    for upload_id in upload_ids:
        upload = uploads_dict.get(upload_id)
        if not upload:
            continue
        task = ProcessingTask(
            document_upload_id=upload_id,
            knowledge_base_id=kb_id,
            status="pending"
        )
        scheduled.append((upload_id, upload, task))

    if not scheduled:
        return {"tasks": []}

    db.add_all([task for _, _, task in scheduled])
    db.commit()

    # Reload the tasks so their ids can be read after the commit.
    for _, _, task in scheduled:
        db.refresh(task)

    task_info = []
    task_data = []
    for upload_id, upload, task in scheduled:
        task_info.append({
            "upload_id": upload_id,
            "task_id": task.id
        })
        task_data.append({
            "task_id": task.id,
            "upload_id": upload_id,
            "temp_path": upload.temp_path,
            "file_name": upload.file_name
        })

    background_tasks.add_task(
        add_processing_tasks_to_queue,
        task_data,
        kb_id
    )

    return {"tasks": task_info}
|
||||
|
||||
# Strong references to in-flight processing tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_inflight_processing_tasks = set()


async def add_processing_tasks_to_queue(task_data, kb_id):
    """
    Helper function to add document processing tasks to the queue without blocking the main response.

    Args:
        task_data: Iterable of dicts with the keys "temp_path", "file_name"
            and "task_id".
        kb_id: Id of the knowledge base the documents belong to.
    """
    for data in task_data:
        task = asyncio.create_task(
            process_document_background(
                data["temp_path"],
                data["file_name"],
                kb_id,
                data["task_id"],
                None
            )
        )
        # Hold a reference until the task completes, then drop it.
        _inflight_processing_tasks.add(task)
        task.add_done_callback(_inflight_processing_tasks.discard)
    logger.info(f"Added {len(task_data)} document processing tasks to queue")
|
||||
|
||||
@router.post("/cleanup")
async def cleanup_temp_files(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Remove upload records older than 24 hours and their temporary files.

    A failure to remove a MinIO object is logged and does not stop the
    run; the database row is deleted either way.
    """
    cutoff = datetime.utcnow() - timedelta(hours=24)
    stale_uploads = (
        db.query(DocumentUpload)
        .filter(DocumentUpload.created_at < cutoff)
        .all()
    )

    storage = get_minio_client()
    for stale in stale_uploads:
        try:
            storage.remove_object(
                bucket_name=settings.MINIO_BUCKET_NAME,
                object_name=stale.temp_path,
            )
        except MinioException as e:
            logger.error(f"Failed to delete temp file {stale.temp_path}: {str(e)}")

        db.delete(stale)

    db.commit()

    removed_count = len(stale_uploads)
    return {"message": f"Cleaned up {removed_count} expired uploads"}
|
||||
|
||||
@router.get("/{kb_id}/documents/tasks")
async def get_processing_tasks(
    kb_id: int,
    task_ids: str = Query(..., description="Comma-separated list of task IDs to check status for"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Get status of multiple processing tasks.

    Returns:
        A dict keyed by task id. Each value holds the task's document id,
        status, error message, upload id and upload file name (None when
        the task has no upload).

    Raises:
        HTTPException: 400 if `task_ids` contains a non-integer entry;
            404 if the knowledge base does not exist for the current user.
    """
    # Parse defensively: the value is raw query text. Empty segments
    # (e.g. a trailing comma) are ignored.
    try:
        task_id_list = [
            int(segment.strip())
            for segment in task_ids.split(",")
            if segment.strip()
        ]
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="task_ids must be a comma-separated list of integers"
        ) from None

    kb = db.query(KnowledgeBase).filter(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.user_id == current_user.id
    ).first()

    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    tasks = (
        db.query(ProcessingTask)
        .options(
            selectinload(ProcessingTask.document_upload)
        )
        .filter(
            ProcessingTask.id.in_(task_id_list),
            ProcessingTask.knowledge_base_id == kb_id
        )
        .all()
    )

    return {
        task.id: {
            "document_id": task.document_id,
            "status": task.status,
            "error_message": task.error_message,
            "upload_id": task.document_upload_id,
            "file_name": task.document_upload.file_name if task.document_upload else None
        }
        for task in tasks
    }
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(
    *,
    db: Session = Depends(get_db),
    kb_id: int,
    doc_id: int,
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Fetch a single document of one of the current user's knowledge bases.

    Raises:
        HTTPException: 404 if the document does not exist in that
            knowledge base for the current user.
    """
    lookup = db.query(Document).join(KnowledgeBase).filter(
        Document.id == doc_id,
        Document.knowledge_base_id == kb_id,
        KnowledgeBase.user_id == current_user.id,
    )
    found = lookup.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@router.post("/test-retrieval")
async def test_retrieval(
    request: TestRetrievalRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    Test retrieval quality for a given query against a knowledge base.

    Runs a similarity search for `request.query` on the `kb_{kb_id}`
    collection and returns up to `request.top_k` hits.

    Returns:
        `{"results": [{"content": ..., "metadata": ..., "score": ...}, ...]}`

    Raises:
        HTTPException: 404 if the knowledge base does not exist for the
            current user; 500 if building the store or searching fails.
    """
    try:
        kb = db.query(KnowledgeBase).filter(
            KnowledgeBase.id == request.kb_id,
            KnowledgeBase.user_id == current_user.id
        ).first()

        if not kb:
            raise HTTPException(
                status_code=404,
                detail=f"Knowledge base {request.kb_id} not found",
            )

        embeddings = EmbeddingsFactory.create()

        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{request.kb_id}",
            embedding_function=embeddings,
        )

        results = vector_store.similarity_search_with_score(request.query, k=request.top_k)

        response = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in results
        ]

        return {"results": response}

    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged; the
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
84
rag-web-ui/backend/app/api/api_v1/testing.py
Normal file
84
rag-web-ui/backend/app/api/api_v1/testing.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models.knowledge import Document, KnowledgeBase
|
||||
from app.models.user import User
|
||||
from app.schemas.testing import TestingPipelineRequest, TestingPipelineResponse
|
||||
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
||||
from app.services.retrieval.multi_kb_retriever import MultiKBRetriever, format_retrieval_context
|
||||
from app.services.testing_pipeline import run_testing_pipeline
|
||||
from app.services.vector_store import VectorStoreFactory
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _build_kb_vector_stores(db: Session, knowledge_bases: List[KnowledgeBase]) -> List[Dict[str, Any]]:
    """
    Create a vector store handle for every knowledge base that has documents.

    Knowledge bases without any `Document` row are skipped. One embeddings
    instance is shared by all stores.

    Returns:
        A list of `{"kb_id": <id>, "store": <vector store>}` dicts, in the
        order of `knowledge_bases`.
    """
    embeddings = EmbeddingsFactory.create()
    kb_vector_stores: List[Dict[str, Any]] = []

    for kb in knowledge_bases:
        # Only existence matters here, so fetch at most one row instead of
        # loading every document of the knowledge base.
        has_documents = (
            db.query(Document)
            .filter(Document.knowledge_base_id == kb.id)
            .first()
            is not None
        )
        if not has_documents:
            continue

        store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb.id}",
            embedding_function=embeddings,
        )
        kb_vector_stores.append({"kb_id": kb.id, "store": store})

    return kb_vector_stores
|
||||
|
||||
|
||||
@router.post("/generate", response_model=TestingPipelineResponse)
async def generate_testing_content(
    *,
    payload: TestingPipelineRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> Any:
    """
    Run the testing pipeline for the submitted requirement text.

    The knowledge context defaults to the stripped
    `payload.knowledge_context`. When knowledge base ids are supplied,
    the caller's matching knowledge bases are searched, and any retrieved
    rows replace that default with their formatted form.
    """
    context_text = (payload.knowledge_context or "").strip()

    if payload.knowledge_base_ids:
        # Only knowledge bases owned by the caller are considered.
        owned_kbs = (
            db.query(KnowledgeBase)
            .filter(
                KnowledgeBase.id.in_(payload.knowledge_base_ids),
                KnowledgeBase.user_id == current_user.id,
            )
            .all()
        )

        stores = await _build_kb_vector_stores(db, owned_kbs)
        if stores:
            multi_retriever = MultiKBRetriever(
                reranker_weight=settings.RERANKER_WEIGHT,
            )
            rows = await multi_retriever.retrieve(
                query=payload.requirement_text,
                kb_vector_stores=stores,
                fetch_k_per_kb=max(12, payload.retrieval_top_k * 2),
                top_k=payload.retrieval_top_k,
            )
            # Keep the client-supplied context when nothing was retrieved.
            if rows:
                context_text = format_retrieval_context(rows)

    return run_testing_pipeline(
        user_requirement_text=payload.requirement_text,
        requirement_type_input=payload.requirement_type,
        debug=payload.debug,
        knowledge_context=context_text,
        use_model_generation=payload.use_model_generation,
        max_items_per_group=payload.max_items_per_group,
        cases_per_item=payload.cases_per_item,
        max_focus_points=payload.max_focus_points,
        max_llm_calls=payload.max_llm_calls,
    )
|
||||
175
rag-web-ui/backend/app/api/api_v1/tools.py
Normal file
175
rag-web-ui/backend/app/api/api_v1/tools.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models.tooling import SRSExtraction, ToolJob
|
||||
from app.models.user import User
|
||||
from app.schemas.tooling import (
|
||||
SRSToolCreateJobResponse,
|
||||
SRSToolJobStatusResponse,
|
||||
SRSToolRequirementsSaveRequest,
|
||||
SRSToolResultResponse,
|
||||
ToolDefinitionResponse,
|
||||
)
|
||||
from app.services.srs_job_service import (
|
||||
build_result_response,
|
||||
ensure_upload_path,
|
||||
replace_requirements,
|
||||
run_srs_job,
|
||||
)
|
||||
from app.tools.registry import ToolRegistry
|
||||
from app.tools.srs_reqs_qwen import get_srs_tool
|
||||
|
||||
# Router for the tool endpoints; api.py mounts it under the "/tools" prefix.
router = APIRouter()
|
||||
|
||||
# Register SRS tool when the router is imported.
|
||||
get_srs_tool()  # called only for its registration effect; the return value is discarded
|
||||
|
||||
# Upload suffixes accepted by create_srs_job; it compares the lower-cased
# suffix of the uploaded file name against this set.
ALLOWED_EXTENSIONS = {".pdf", ".docx"}
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolDefinitionResponse])
async def list_tools(
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return everything ``ToolRegistry.list()`` reports.

    ``current_user`` is never read; the dependency is declared so that
    ``get_current_user`` runs for this route (presumably enforcing
    authentication - confirm against ``app.core.security``).
    """
    _ = current_user  # intentionally unused
    registered_tools = ToolRegistry.list()
    return registered_tools
|
||||
|
||||
|
||||
@router.post("/srs/jobs", response_model=SRSToolCreateJobResponse)
async def create_srs_job(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Store an uploaded SRS document and schedule its extraction job.

    A ``ToolJob`` row is committed first so that its id can be used to build
    the upload path. The file is then written to disk, the job is updated
    with the resolved path, and ``run_srs_job`` is queued as a background task.

    Args:
        background_tasks: FastAPI background task queue for this request.
        file: Uploaded document; only names ending in ``.pdf`` or ``.docx``
            (case-insensitive) are accepted.
        db: Database session.
        current_user: Owner of the new job.

    Returns:
        A dict with the new job's ``job_id`` and ``status``.

    Raises:
        HTTPException: 400 when the file extension is not allowed; 500 when
            the upload cannot be stored (the job is then marked ``failed``).
    """
    # Path(...).name keeps only the final component of the client-supplied name.
    safe_name = Path(file.filename or "").name
    ext = Path(safe_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="仅支持 .pdf/.docx 文件")

    job = ToolJob(
        user_id=current_user.id,
        tool_name="srs.requirement_extractor",
        status="pending",
        input_file_name=safe_name,
        input_file_path="",
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        # Path preparation is part of storing the upload: if it fails, the
        # already-committed job must be marked failed rather than being left
        # in "pending" forever.
        target_path = ensure_upload_path(job.id, safe_name)
        content = await file.read()
        target_path.write_bytes(content)
    except Exception as exc:
        job.status = "failed"
        job.error_message = f"保存上传文件失败: {exc}"
        db.add(job)
        db.commit()
        # Chain the cause so the original error stays visible in tracebacks.
        raise HTTPException(status_code=500, detail="上传文件保存失败") from exc

    job.input_file_path = str(target_path.resolve())
    db.add(job)
    db.commit()

    background_tasks.add_task(run_srs_job, job.id)

    return {
        "job_id": job.id,
        "status": job.status,
    }
|
||||
|
||||
|
||||
@router.get("/srs/jobs/{job_id}", response_model=SRSToolJobStatusResponse)
async def get_srs_job_status(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Report the state of one of the current user's SRS jobs.

    Returns:
        A dict with the job's id, tool name, status, error message and
        start/completion timestamps, plus ``extraction_id`` (``None`` when no
        ``SRSExtraction`` row exists for the job).

    Raises:
        HTTPException: 404 when no job with ``job_id`` belongs to the user.
    """
    # Filtering on user_id hides other users' jobs behind the same 404.
    job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction = db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    if extraction:
        extraction_id = extraction.id
    else:
        extraction_id = None

    status_payload = {
        "job_id": job.id,
        "tool_name": job.tool_name,
        "status": job.status,
        "error_message": job.error_message,
        "extraction_id": extraction_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
    }
    return status_payload
|
||||
|
||||
|
||||
@router.get("/srs/jobs/{job_id}/result", response_model=SRSToolResultResponse)
async def get_srs_job_result(
    job_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Return the extraction result of a completed SRS job.

    Returns:
        The value of ``build_result_response(job, extraction)``.

    Raises:
        HTTPException: 404 when the job does not exist for the current user
            or has no ``SRSExtraction`` row; 409 when the job status is not
            ``"completed"``.
    """
    # Filtering on user_id hides other users' jobs behind the same 404.
    job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    is_completed = job.status == "completed"
    if not is_completed:
        raise HTTPException(status_code=409, detail="任务尚未完成")

    extraction = db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    result_payload = build_result_response(job, extraction)
    return result_payload
|
||||
|
||||
|
||||
@router.put("/srs/jobs/{job_id}/requirements", response_model=SRSToolResultResponse)
async def save_srs_requirements(
    job_id: int,
    payload: SRSToolRequirementsSaveRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Any:
    """Replace the stored requirements of an SRS job with the submitted list.

    Each item of ``payload.requirements`` is converted with ``.dict()`` and
    handed to ``replace_requirements``; the extraction is then committed and
    refreshed before the response is built.

    Returns:
        The value of ``build_result_response(job, extraction)`` after the save.

    Raises:
        HTTPException: 404 when the job does not exist for the current user
            or has no ``SRSExtraction`` row.
    """
    # Filtering on user_id hides other users' jobs behind the same 404.
    job_query = db.query(ToolJob).filter(
        ToolJob.id == job_id,
        ToolJob.user_id == current_user.id,
    )
    job = job_query.first()
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在")

    extraction = db.query(SRSExtraction).filter(SRSExtraction.job_id == job.id).first()
    if not extraction:
        raise HTTPException(status_code=404, detail="任务结果不存在")

    requirement_dicts = [entry.dict() for entry in payload.requirements]
    replace_requirements(db, extraction, requirement_dicts)

    db.add(extraction)
    db.commit()
    db.refresh(extraction)

    updated_result = build_result_response(job, extraction)
    return updated_result
|
||||
Reference in New Issue
Block a user