212 lines
7.9 KiB
Python
212 lines
7.9 KiB
Python
|
|
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.model_config import UserModelConfig
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
|
||
|
|
|
||
|
|
|
||
|
|
# Catalogue of the model providers this module accepts. Every entry carries
# the provider key, a display label, default endpoint/model values, the model
# names offered as choices, and the two flags ("requires_api_key",
# "supports_custom_api_base") that payload validation in this module reads.
# The first entry also supplies the top-level defaults returned by
# provider_options_response().
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
    # DashScope: endpoint from settings; model names from settings with
    # literal fallbacks.
    {
        "provider": "dashscope",
        "label": "DashScope",
        "default_api_base": settings.DASH_SCOPE_API_BASE,
        "default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
        "default_embedding_model": (
            settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4"
        ),
        "chat_models": [
            "qwen3-max",
            "qwen-plus",
            "qwen-turbo",
            "qwen-max",
        ],
        "embedding_models": [
            "text-embedding-v4",
            "text-embedding-v3",
            "text-embedding-v2",
        ],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    # OpenAI: endpoint from settings; model names from settings with
    # literal fallbacks.
    {
        "provider": "openai",
        "label": "OpenAI",
        "default_api_base": settings.OPENAI_API_BASE,
        "default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
        "default_embedding_model": (
            settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small"
        ),
        "chat_models": [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4.1",
            "gpt-4.1-mini",
        ],
        "embedding_models": [
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    # Generic OpenAI-compatible endpoint: there is no default API base, so
    # one has to be supplied with the config.
    {
        "provider": "openai_compatible",
        "label": "OpenAI Compatible",
        "default_api_base": "",
        "default_chat_model": "qwen3-max",
        "default_embedding_model": "text-embedding-v4",
        "chat_models": [
            "qwen3-max",
            "deepseek-chat",
            "gpt-4o-mini",
        ],
        "embedding_models": [
            "text-embedding-v4",
            "text-embedding-3-small",
        ],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    # Ollama: every default comes straight from settings and no API key is
    # required.
    {
        "provider": "ollama",
        "label": "Ollama",
        "default_api_base": settings.OLLAMA_API_BASE,
        "default_chat_model": settings.OLLAMA_MODEL,
        "default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
        "chat_models": [
            settings.OLLAMA_MODEL,
            "llama3.1",
            "qwen2.5",
            "deepseek-r1:7b",
        ],
        "embedding_models": [
            settings.OLLAMA_EMBEDDINGS_MODEL,
            "nomic-embed-text",
            "mxbai-embed-large",
        ],
        "requires_api_key": False,
        "supports_custom_api_base": True,
    },
]
|
||
|
|
|
||
|
|
|
||
|
|
def provider_options_response() -> Dict[str, Any]:
    """Return the provider catalogue together with the default selection.

    Returns:
        A dict with ``providers`` (the module-level ``PROVIDER_OPTIONS`` list
        itself, not a copy) and ``defaults`` (provider key, API base, chat
        model and embedding model taken from the first catalogue entry).
    """
    head = PROVIDER_OPTIONS[0]
    # Map each response field to the catalogue key it is read from.
    defaults = {
        field: head[source_key]
        for field, source_key in (
            ("provider", "provider"),
            ("api_base", "default_api_base"),
            ("chat_model", "default_chat_model"),
            ("embedding_model", "default_embedding_model"),
        )
    }
    return {"providers": PROVIDER_OPTIONS, "defaults": defaults}
|
||
|
|
|
||
|
|
|
||
|
|
def _provider_option(provider: str) -> Dict[str, Any]:
    """Look up the catalogue entry for *provider*.

    The name is trimmed and lower-cased before matching; a falsy value is
    treated as ``"dashscope"``.

    Raises:
        ValueError: If no entry in ``PROVIDER_OPTIONS`` has that provider key.
    """
    wanted = (provider or "dashscope").strip().lower()
    found = next(
        (entry for entry in PROVIDER_OPTIONS if entry["provider"] == wanted),
        None,
    )
    if found is None:
        raise ValueError(f"Unsupported model provider: {provider}")
    return found
|
||
|
|
|
||
|
|
|
||
|
|
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
    """Validate *payload* and return the full set of config field values.

    A field that is missing from *payload* (or is ``None``) falls back to the
    value on *existing* when one is given, and after that to the defaults of
    the selected provider.

    Args:
        payload: Raw field values keyed by field name.
        existing: The stored config being updated, or ``None`` when creating.

    Returns:
        A dict with the keys ``name``, ``provider``, ``api_key``,
        ``api_base``, ``chat_model``, ``embedding_model`` and ``is_active``.
        ``api_base`` is ``None`` when no base URL applies.

    Raises:
        ValueError: If the provider is unsupported, the OpenAI-compatible
            provider has no API base, a model name is blank, or the provider
            requires an API key and none is available.
    """
    provider = str(payload.get("provider") or getattr(existing, "provider", "dashscope")).strip().lower()
    option = _provider_option(provider)

    api_base = payload.get("api_base")
    if api_base is None and existing is not None:
        api_base = existing.api_base
    # Strip BEFORE testing for emptiness: a whitespace-only value must count
    # as blank, otherwise it would skip the provider default and the
    # required-URL check below and end up stored as an empty string.
    api_base = str(api_base).strip() if api_base else ""
    if not api_base:
        api_base = str(option["default_api_base"] or "").strip()
    if option["supports_custom_api_base"] and provider == "openai_compatible" and not api_base:
        raise ValueError("OpenAI Compatible provider requires an API base URL.")

    chat_model = str(
        payload.get("chat_model")
        or getattr(existing, "chat_model", "")
        or option["default_chat_model"]
    ).strip()
    embedding_model = str(
        payload.get("embedding_model")
        or getattr(existing, "embedding_model", "")
        or option["default_embedding_model"]
    ).strip()
    if not chat_model:
        raise ValueError("Chat model is required.")
    if not embedding_model:
        raise ValueError("Embedding model is required.")

    api_key = payload.get("api_key")
    if api_key is None and existing is not None:
        api_key = existing.api_key
    api_key = str(api_key or "").strip()
    if option["requires_api_key"] and not api_key:
        raise ValueError("API key is required for this provider.")

    name = payload.get("name")
    if name is None and existing is not None:
        name = existing.name
    name = str(name or "").strip()
    if not name:
        # Unnamed configs are labelled after their provider.
        name = option["label"]

    # Treat an explicit None like an absent key, matching how name, api_key
    # and api_base are handled above; bool(None) would otherwise silently
    # deactivate the config.
    is_active = payload.get("is_active")
    if is_active is None:
        is_active = getattr(existing, "is_active", True)

    return {
        "name": name,
        "provider": provider,
        "api_key": api_key,
        "api_base": api_base or None,
        "chat_model": chat_model,
        "embedding_model": embedding_model,
        "is_active": bool(is_active),
    }
|
||
|
|
|
||
|
|
|
||
|
|
class ModelConfigService:
    """Database operations on a user's ``UserModelConfig`` rows.

    All methods are static and take the SQLAlchemy session explicitly.
    Methods that write commit the session themselves.
    """

    @staticmethod
    def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
        """Return all configs of *user_id*, active first, then most recently updated."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id)
            .order_by(UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
            .all()
        )

    @staticmethod
    def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
        """Return the config with *config_id* if it belongs to *user_id*, else ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
            .first()
        )

    @staticmethod
    def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
        """Return the most recently updated active config of *user_id*, or ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
            .order_by(UserModelConfig.updated_at.desc())
            .first()
        )

    @staticmethod
    def require_active_config(db: Session, user_id: int) -> UserModelConfig:
        """Return the active config of *user_id*.

        Raises:
            ValueError: If the user has no active config.
        """
        config = ModelConfigService.get_active_config(db, user_id)
        if config is None:
            raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
        return config

    @staticmethod
    def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
        """Validate *payload*, store it as a new config for *user_id* and return it.

        When the new config is active, the user's other configs are
        deactivated in the same transaction.

        Raises:
            ValueError: If the payload fails validation.
        """
        data = _normalized_payload(payload.model_dump())
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, user_id)
        item = UserModelConfig(user_id=user_id, **data)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def update_config(
        db: Session,
        item: UserModelConfig,
        payload: ModelConfigUpdate,
    ) -> UserModelConfig:
        """Apply the fields set on *payload* to *item*, commit, and return it.

        A blank ``api_key`` means "keep the stored key". When the result is
        active, the user's other configs are deactivated in the same
        transaction.

        Raises:
            ValueError: If the merged values fail validation.
        """
        raw = payload.model_dump(exclude_unset=True)
        # Whitespace-only keys are dropped as well as "", so they neither
        # wipe the stored key nor trip the "API key is required" check.
        submitted_key = raw.get("api_key")
        if isinstance(submitted_key, str) and not submitted_key.strip():
            raw.pop("api_key")
        data = _normalized_payload(raw, existing=item)
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
        for field, value in data.items():
            setattr(item, field, value)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def delete_config(db: Session, item: UserModelConfig) -> None:
        """Delete *item* and commit."""
        db.delete(item)
        db.commit()

    @staticmethod
    def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Stamp *item* with the current UTC time, commit, and return it."""
        # datetime.utcnow() is deprecated since Python 3.12. This yields the
        # same naive UTC timestamp, so stored values keep their shape.
        item.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Set ``is_active`` to ``False`` on the configs of *user_id*.

        The row with id *exclude_id*, when given, is left untouched. This
        issues a bulk UPDATE and does not commit; the caller commits.
        """
        query = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
        if exclude_id is not None:
            query = query.filter(UserModelConfig.id != exclude_id)
        query.update({UserModelConfig.is_active: False}, synchronize_session=False)
|