Files
rag_agent/rag-web-ui/backend/app/services/model_config.py

212 lines
7.9 KiB
Python
Raw Normal View History

from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.model_config import UserModelConfig
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
# Catalogue of the model providers this module accepts. Each entry carries the
# provider key, a display label, default endpoint/model values, the suggested
# chat/embedding model names, and two capability flags.
# The first entry doubles as the global default (see provider_options_response).
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
{
"provider": "dashscope",
"label": "DashScope",
"default_api_base": settings.DASH_SCOPE_API_BASE,
# A falsy setting falls back to the hard-coded model name.
"default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
"default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
"chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
"embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
{
"provider": "openai",
"label": "OpenAI",
"default_api_base": settings.OPENAI_API_BASE,
"default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
"default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
"chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
"embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
{
"provider": "openai_compatible",
"label": "OpenAI Compatible",
# No default endpoint: _normalized_payload rejects this provider when no
# API base is supplied.
"default_api_base": "",
"default_chat_model": "qwen3-max",
"default_embedding_model": "text-embedding-v4",
"chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
"embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
"requires_api_key": True,
"supports_custom_api_base": True,
},
{
"provider": "ollama",
"label": "Ollama",
"default_api_base": settings.OLLAMA_API_BASE,
# Unlike the entries above, the Ollama defaults have no literal fallback.
"default_chat_model": settings.OLLAMA_MODEL,
"default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
# NOTE(review): the configured model is listed first; if it equals one of
# the literals (or is unset) the list may contain a duplicate or None -
# confirm how the consumer of these lists handles that.
"chat_models": [settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"],
"embedding_models": [settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"],
# The only provider in this list that does not require an API key.
"requires_api_key": False,
"supports_custom_api_base": True,
},
]
def provider_options_response() -> Dict[str, Any]:
    """Return the provider catalogue together with the global default selection.

    The defaults are taken from the first entry of ``PROVIDER_OPTIONS``.

    Returns:
        A dict with ``"providers"`` (the full option list) and ``"defaults"``
        (provider key, API base, chat model and embedding model of the first
        option).
    """
    primary = PROVIDER_OPTIONS[0]
    # Response key -> key of the provider option that supplies its value.
    default_sources = {
        "provider": "provider",
        "api_base": "default_api_base",
        "chat_model": "default_chat_model",
        "embedding_model": "default_embedding_model",
    }
    defaults = {
        response_key: primary[option_key]
        for response_key, option_key in default_sources.items()
    }
    return {"providers": PROVIDER_OPTIONS, "defaults": defaults}
def _provider_option(provider: str) -> Dict[str, Any]:
    """Look up the ``PROVIDER_OPTIONS`` entry for ``provider``.

    The key is trimmed and lower-cased before matching; a falsy value is
    treated as ``"dashscope"``.

    Raises:
        ValueError: If no option matches the normalized key.
    """
    wanted = (provider or "dashscope").strip().lower()
    found = next(
        (entry for entry in PROVIDER_OPTIONS if entry["provider"] == wanted),
        None,
    )
    if found is None:
        raise ValueError(f"Unsupported model provider: {provider}")
    return found
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
    """Validate raw model-config fields and normalize them into column values.

    Every field is resolved in the same order: the value in ``payload``, then
    the value on ``existing`` (when updating), then the provider's default.
    A ``None`` in ``payload`` always means "not provided".

    Args:
        payload: Raw field values, e.g. the ``model_dump()`` of a request schema.
        existing: The row being updated, or ``None`` when creating.

    Returns:
        A dict with ``name``, ``provider``, ``api_key``, ``api_base``,
        ``chat_model``, ``embedding_model`` and ``is_active``.

    Raises:
        ValueError: If the provider is unsupported, the OpenAI Compatible
            provider has no API base, a model name is missing, or the
            provider requires an API key and none is available.
    """
    requested = str(
        payload.get("provider") or getattr(existing, "provider", None) or "dashscope"
    )
    option = _provider_option(requested)
    # Use the catalogue's canonical key so a blank/whitespace provider that
    # resolved to the default is not stored as an empty string.
    provider = option["provider"]

    # NOTE(review): when an update switches provider without sending api_base
    # or model names, the previous provider's values are inherited - confirm
    # that callers always send these fields together with a provider change.
    api_base = payload.get("api_base")
    if api_base is None and existing is not None:
        api_base = existing.api_base
    # Strip before testing so a whitespace-only value counts as "missing"
    # and cannot slip past the required-base check below.
    api_base = str(api_base or "").strip()
    if not api_base:
        api_base = str(option["default_api_base"] or "").strip()
    if option["supports_custom_api_base"] and provider == "openai_compatible" and not api_base:
        raise ValueError("OpenAI Compatible provider requires an API base URL.")

    # The trailing `or ""` keeps an unset default (None) from being turned
    # into the literal string "None" by str().
    chat_model = str(
        payload.get("chat_model")
        or getattr(existing, "chat_model", "")
        or option["default_chat_model"]
        or ""
    ).strip()
    embedding_model = str(
        payload.get("embedding_model")
        or getattr(existing, "embedding_model", "")
        or option["default_embedding_model"]
        or ""
    ).strip()
    if not chat_model:
        raise ValueError("Chat model is required.")
    if not embedding_model:
        raise ValueError("Embedding model is required.")

    api_key = payload.get("api_key")
    if api_key is None and existing is not None:
        api_key = existing.api_key
    api_key = str(api_key or "").strip()
    if option["requires_api_key"] and not api_key:
        raise ValueError("API key is required for this provider.")

    name = payload.get("name")
    if name is None and existing is not None:
        name = existing.name
    # An empty name falls back to the provider's display label.
    name = str(name or "").strip() or option["label"]

    # Treat an explicit None like a missing key, as for the other fields,
    # instead of coercing it to False.
    is_active = payload.get("is_active")
    if is_active is None:
        is_active = getattr(existing, "is_active", True)

    return {
        "name": name,
        "provider": provider,
        "api_key": api_key,
        "api_base": api_base or None,
        "chat_model": chat_model,
        "embedding_model": embedding_model,
        "is_active": bool(is_active),
    }
class ModelConfigService:
    """Database operations on ``UserModelConfig`` rows, scoped to one user."""

    @staticmethod
    def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
        """Return all configs of ``user_id``, active ones first, then newest update first."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id)
            .order_by(UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
            .all()
        )

    @staticmethod
    def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
        """Return the config with ``config_id`` if it belongs to ``user_id``, else ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
            .first()
        )

    @staticmethod
    def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
        """Return the most recently updated active config of ``user_id``, or ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
            .order_by(UserModelConfig.updated_at.desc())
            .first()
        )

    @staticmethod
    def require_active_config(db: Session, user_id: int) -> UserModelConfig:
        """Return the active config of ``user_id``.

        Raises:
            ValueError: If the user has no active config.
        """
        config = ModelConfigService.get_active_config(db, user_id)
        if config is None:
            # Message: "Please add and enable a model configuration on the
            # API key page first."
            raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
        return config

    @staticmethod
    def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
        """Validate ``payload`` and insert it as a new config for ``user_id``.

        When the new config is active, the user's other configs are
        deactivated in the same transaction.

        Raises:
            ValueError: If the payload fails validation.
        """
        data = _normalized_payload(payload.model_dump())
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, user_id)
        item = UserModelConfig(user_id=user_id, **data)
        return ModelConfigService._save(db, item)

    @staticmethod
    def update_config(
        db: Session,
        item: UserModelConfig,
        payload: ModelConfigUpdate,
    ) -> UserModelConfig:
        """Apply the fields set in ``payload`` to ``item`` and persist it.

        A blank ``api_key`` means "keep the stored key". When the result is
        active, the user's other configs are deactivated in the same
        transaction.

        Raises:
            ValueError: If the merged values fail validation.
        """
        raw = payload.model_dump(exclude_unset=True)
        submitted_key = raw.get("api_key")
        # Whitespace-only counts as blank too; otherwise it would replace the
        # stored key with an empty one and fail validation.
        if isinstance(submitted_key, str) and not submitted_key.strip():
            raw.pop("api_key")
        data = _normalized_payload(raw, existing=item)
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
        for field, value in data.items():
            setattr(item, field, value)
        return ModelConfigService._save(db, item)

    @staticmethod
    def delete_config(db: Session, item: UserModelConfig) -> None:
        """Delete ``item`` and commit."""
        db.delete(item)
        ModelConfigService._commit(db)

    @staticmethod
    def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Set ``item.last_used_at`` to the current UTC time and persist it."""
        # datetime.utcnow() is deprecated since Python 3.12; this yields the
        # same naive UTC value.
        item.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
        return ModelConfigService._save(db, item)

    @staticmethod
    def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Mark the configs of ``user_id`` inactive, optionally sparing ``exclude_id``.

        Issues a bulk UPDATE without committing; the caller commits.
        """
        query = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
        if exclude_id is not None:
            query = query.filter(UserModelConfig.id != exclude_id)
        # synchronize_session=False: objects already loaded in this session
        # are not updated in memory by the bulk UPDATE.
        query.update({UserModelConfig.is_active: False}, synchronize_session=False)

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit the session; on failure roll it back and re-raise."""
        try:
            db.commit()
        except Exception:
            # Leave the session usable for the caller before propagating.
            db.rollback()
            raise

    @staticmethod
    def _save(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Add ``item`` to the session, commit, and refresh it from the database."""
        db.add(item)
        ModelConfigService._commit(db)
        db.refresh(item)
        return item