init. project
This commit is contained in:
0
rag-web-ui/backend/app/core/__init__.py
Normal file
0
rag-web-ui/backend/app/core/__init__.py
Normal file
123
rag-web-ui/backend/app/core/config.py
Normal file
123
rag-web-ui/backend/app/core/config.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "RAG Web UI" # Project name
|
||||
VERSION: str = "0.1.0" # Project version
|
||||
API_V1_STR: str = "/api" # API version string
|
||||
|
||||
# MySQL settings
|
||||
MYSQL_SERVER: str = os.getenv("MYSQL_SERVER", "localhost")
|
||||
MYSQL_PORT: int = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
MYSQL_USER: str = os.getenv("MYSQL_USER", "ragagent")
|
||||
MYSQL_PASSWORD: str = os.getenv("MYSQL_PASSWORD", "ragagent")
|
||||
MYSQL_DATABASE: str = os.getenv("MYSQL_DATABASE", "ragagent")
|
||||
SQLALCHEMY_DATABASE_URI: Optional[str] = None
|
||||
|
||||
@property
|
||||
def get_database_url(self) -> str:
|
||||
if self.SQLALCHEMY_DATABASE_URI:
|
||||
return self.SQLALCHEMY_DATABASE_URI
|
||||
return (
|
||||
f"mysql+mysqlconnector://{self.MYSQL_USER}:{self.MYSQL_PASSWORD}"
|
||||
f"@{self.MYSQL_SERVER}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
|
||||
)
|
||||
|
||||
# JWT settings
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "10080"))
|
||||
|
||||
# Chat Provider settings
|
||||
CHAT_PROVIDER: str = os.getenv("CHAT_PROVIDER", "openai")
|
||||
|
||||
# Embeddings settings
|
||||
EMBEDDINGS_PROVIDER: str = os.getenv("EMBEDDINGS_PROVIDER", "openai")
|
||||
|
||||
# MinIO settings
|
||||
MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "localhost:9000")
|
||||
MINIO_ACCESS_KEY: str = os.getenv("MINIO_ACCESS_KEY", "minioadmin")
|
||||
MINIO_SECRET_KEY: str = os.getenv("MINIO_SECRET_KEY", "minioadmin")
|
||||
MINIO_BUCKET_NAME: str = os.getenv("MINIO_BUCKET_NAME", "documents")
|
||||
|
||||
# Shared model API key fallback
|
||||
API_KEY: str = os.getenv("API_KEY", "")
|
||||
|
||||
# OpenAI settings
|
||||
OPENAI_API_BASE: str = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||||
OPENAI_API_KEY: str = os.getenv(
|
||||
"OPENAI_API_KEY", os.getenv("API_KEY", "your-openai-api-key-here")
|
||||
)
|
||||
OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "gpt-4")
|
||||
OPENAI_EMBEDDINGS_MODEL: str = os.getenv("OPENAI_EMBEDDINGS_MODEL", "text-embedding-ada-002")
|
||||
|
||||
# DashScope settings
|
||||
DASH_SCOPE_API_KEY: str = os.getenv(
|
||||
"DASH_SCOPE_API_KEY",
|
||||
os.getenv("DASHSCOPE_API_KEY", os.getenv("API_KEY", "")),
|
||||
)
|
||||
DASH_SCOPE_API_BASE: str = os.getenv(
|
||||
"DASH_SCOPE_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
DASH_SCOPE_CHAT_MODEL: str = os.getenv("DASH_SCOPE_CHAT_MODEL", "qwen3-max")
|
||||
DASH_SCOPE_EMBEDDINGS_MODEL: str = os.getenv("DASH_SCOPE_EMBEDDINGS_MODEL", "")
|
||||
|
||||
# Vector Store settings
|
||||
VECTOR_STORE_TYPE: str = os.getenv("VECTOR_STORE_TYPE", "chroma")
|
||||
|
||||
# External reranker settings
|
||||
RERANKER_API_URL: str = os.getenv("RERANKER_API_URL", "")
|
||||
RERANKER_API_KEY: str = os.getenv(
|
||||
"RERANKER_API_KEY",
|
||||
os.getenv(
|
||||
"DASH_SCOPE_API_KEY",
|
||||
os.getenv("DASHSCOPE_API_KEY", os.getenv("API_KEY", "")),
|
||||
),
|
||||
)
|
||||
RERANKER_MODEL: str = os.getenv("RERANKER_MODEL", "")
|
||||
RERANKER_TIMEOUT_SECONDS: float = float(os.getenv("RERANKER_TIMEOUT_SECONDS", "8"))
|
||||
RERANKER_WEIGHT: float = float(os.getenv("RERANKER_WEIGHT", "0.75"))
|
||||
|
||||
# GraphRAG settings
|
||||
GRAPHRAG_ENABLED: bool = os.getenv("GRAPHRAG_ENABLED", "false").lower() == "true"
|
||||
GRAPHRAG_WORKING_DIR: str = os.getenv("GRAPHRAG_WORKING_DIR", "./graphrag_cache")
|
||||
GRAPHRAG_GRAPH_STORAGE: str = os.getenv("GRAPHRAG_GRAPH_STORAGE", "neo4j")
|
||||
GRAPHRAG_QUERY_LEVEL: int = int(os.getenv("GRAPHRAG_QUERY_LEVEL", "2"))
|
||||
GRAPHRAG_LOCAL_TOP_K: int = int(os.getenv("GRAPHRAG_LOCAL_TOP_K", "20"))
|
||||
GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING: int = int(os.getenv("GRAPHRAG_ENTITY_EXTRACT_MAX_GLEANING", "1"))
|
||||
GRAPHRAG_EMBEDDING_DIM: int = int(os.getenv("GRAPHRAG_EMBEDDING_DIM", "1024"))
|
||||
GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE: int = int(os.getenv("GRAPHRAG_EMBEDDING_MAX_TOKEN_SIZE", "8192"))
|
||||
|
||||
# Neo4j settings
|
||||
NEO4J_URL: str = os.getenv("NEO4J_URL", "bolt://localhost:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "neo4j")
|
||||
|
||||
# Chroma DB settings
|
||||
CHROMA_DB_HOST: str = os.getenv("CHROMA_DB_HOST", "chromadb")
|
||||
CHROMA_DB_PORT: int = int(os.getenv("CHROMA_DB_PORT", "8000"))
|
||||
|
||||
# Qdrant DB settings
|
||||
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
QDRANT_PREFER_GRPC: bool = os.getenv("QDRANT_PREFER_GRPC", "true").lower() == "true"
|
||||
|
||||
# Deepseek settings
|
||||
DEEPSEEK_API_KEY: str = ""
|
||||
DEEPSEEK_API_BASE: str = "https://api.deepseek.com/v1" # 默认 API 地址
|
||||
DEEPSEEK_MODEL: str = "deepseek-chat" # 默认模型名称
|
||||
|
||||
# Ollama settings
|
||||
OLLAMA_API_BASE: str = "http://localhost:11434"
|
||||
OLLAMA_MODEL: str = "deepseek-r1:7b"
|
||||
OLLAMA_EMBEDDINGS_MODEL: str = os.getenv(
|
||||
"OLLAMA_EMBEDDINGS_MODEL", "nomic-embed-text"
|
||||
) # Added this line
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
29
rag-web-ui/backend/app/core/minio.py
Normal file
29
rag-web-ui/backend/app/core/minio.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import logging
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_minio_client() -> Minio:
|
||||
"""
|
||||
Get a MinIO client instance.
|
||||
"""
|
||||
logger.info("Creating MinIO client instance.")
|
||||
return Minio(
|
||||
settings.MINIO_ENDPOINT,
|
||||
access_key=settings.MINIO_ACCESS_KEY,
|
||||
secret_key=settings.MINIO_SECRET_KEY,
|
||||
secure=False # Set to True if using HTTPS
|
||||
)
|
||||
|
||||
def init_minio():
|
||||
"""
|
||||
Initialize MinIO by creating the bucket if it doesn't exist.
|
||||
"""
|
||||
client = get_minio_client()
|
||||
logger.info(f"Checking if bucket {settings.MINIO_BUCKET_NAME} exists.")
|
||||
if not client.bucket_exists(settings.MINIO_BUCKET_NAME):
|
||||
logger.info(f"Bucket {settings.MINIO_BUCKET_NAME} does not exist. Creating bucket.")
|
||||
client.make_bucket(settings.MINIO_BUCKET_NAME)
|
||||
else:
|
||||
logger.info(f"Bucket {settings.MINIO_BUCKET_NAME} already exists.")
|
||||
27
rag-web-ui/backend/app/core/runtime_checks.py
Normal file
27
rag-web-ui/backend/app/core/runtime_checks.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_runtime_settings(settings: Settings) -> None:
|
||||
errors = []
|
||||
|
||||
if settings.GRAPHRAG_ENABLED:
|
||||
if settings.GRAPHRAG_GRAPH_STORAGE.lower() not in {"neo4j", "networkx"}:
|
||||
errors.append("GRAPHRAG_GRAPH_STORAGE must be either 'neo4j' or 'networkx'.")
|
||||
|
||||
if settings.GRAPHRAG_GRAPH_STORAGE.lower() == "neo4j":
|
||||
if not settings.NEO4J_URL:
|
||||
errors.append("NEO4J_URL is required when GraphRAG Neo4j storage is enabled.")
|
||||
if not settings.NEO4J_USERNAME:
|
||||
errors.append("NEO4J_USERNAME is required when GraphRAG Neo4j storage is enabled.")
|
||||
if not settings.NEO4J_PASSWORD:
|
||||
errors.append("NEO4J_PASSWORD is required when GraphRAG Neo4j storage is enabled.")
|
||||
|
||||
if settings.RERANKER_API_URL and not settings.RERANKER_MODEL:
|
||||
logger.warning("RERANKER_API_URL is configured but RERANKER_MODEL is empty. The API may reject requests.")
|
||||
|
||||
if errors:
|
||||
raise ValueError("Runtime configuration validation failed: " + " | ".join(errors))
|
||||
84
rag-web-ui/backend/app/core/security.py
Normal file
84
rag-web-ui/backend/app/core/security.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
import bcrypt
|
||||
from app.core.config import settings
|
||||
from fastapi import Depends, HTTPException, status, Security
|
||||
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.services.api_key import APIKeyService
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login/access-token")
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
def get_api_key_user(
|
||||
db: Session = Depends(get_db),
|
||||
api_key: str = Security(api_key_header),
|
||||
) -> User:
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API key header missing",
|
||||
)
|
||||
|
||||
api_key_obj = APIKeyService.get_api_key_by_key(db=db, key=api_key)
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
if not api_key_obj.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive API key",
|
||||
)
|
||||
|
||||
APIKeyService.update_last_used(db=db, api_key=api_key_obj)
|
||||
return api_key_obj.user
|
||||
Reference in New Issue
Block a user