from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.model_config import UserModelConfig
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate


def _model_list(*names: Optional[str]) -> List[str]:
    """Return the non-empty model names in their original order, without duplicates.

    Settings-provided defaults may be unset, or may repeat one of the hard-coded
    suggestions, so option lists that mix the two go through this helper.
    """
    return list(dict.fromkeys(name for name in names if name))


def _clean(value: Any) -> str:
    """Return *value* as a stripped string, treating falsy values (e.g. ``None``) as ``""``.

    A plain ``str(None)`` would yield the truthy literal ``"None"`` and slip past
    the "required" checks in ``_normalized_payload``.
    """
    return str(value or "").strip()


# Static catalogue of supported providers; the first entry supplies the
# defaults reported by ``provider_options_response``.
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
    {
        "provider": "dashscope",
        "label": "DashScope",
        "default_api_base": settings.DASH_SCOPE_API_BASE,
        "default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
        "default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
        "chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
        "embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        "provider": "openai",
        "label": "OpenAI",
        "default_api_base": settings.OPENAI_API_BASE,
        "default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
        "default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
        "chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
        "embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        "provider": "openai_compatible",
        "label": "OpenAI Compatible",
        "default_api_base": "",
        "default_chat_model": "qwen3-max",
        "default_embedding_model": "text-embedding-v4",
        "chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
        "embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
        "requires_api_key": True,
        "supports_custom_api_base": True,
    },
    {
        "provider": "ollama",
        "label": "Ollama",
        "default_api_base": settings.OLLAMA_API_BASE,
        # Fall back to a concrete model like the other providers do, so an unset
        # setting never becomes the default (or the literal "None").
        "default_chat_model": settings.OLLAMA_MODEL or "llama3.1",
        "default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL or "nomic-embed-text",
        "chat_models": _model_list(settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"),
        "embedding_models": _model_list(
            settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"
        ),
        "requires_api_key": False,
        "supports_custom_api_base": True,
    },
]


def provider_options_response() -> Dict[str, Any]:
    """Return every provider option plus the defaults taken from the first provider."""
    first = PROVIDER_OPTIONS[0]
    return {
        "providers": PROVIDER_OPTIONS,
        "defaults": {
            "provider": first["provider"],
            "api_base": first["default_api_base"],
            "chat_model": first["default_chat_model"],
            "embedding_model": first["default_embedding_model"],
        },
    }


def _provider_option(provider: str) -> Dict[str, Any]:
    """Look up the ``PROVIDER_OPTIONS`` entry for *provider*.

    The name is compared case-insensitively after stripping whitespace. An empty
    value selects ``"dashscope"``.

    Raises:
        ValueError: if no option matches.
    """
    normalized = (provider or "dashscope").strip().lower()
    for option in PROVIDER_OPTIONS:
        if option["provider"] == normalized:
            return option
    raise ValueError(f"Unsupported model provider: {provider}")


def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
    """Validate *payload* and fill gaps to produce ``UserModelConfig`` column values.

    Values missing from *payload* are taken from *existing* (when updating) and
    then from the provider's defaults. Provider-specific fields (``api_base`` and
    both model names) are only inherited from *existing* when the provider is
    unchanged, so switching provider never keeps the old endpoint or model names.

    Raises:
        ValueError: for an unknown provider, a missing OpenAI Compatible base URL,
            a missing chat/embedding model, or a missing required API key.
    """
    provider = (
        _clean(payload.get("provider")) or _clean(getattr(existing, "provider", None)) or "dashscope"
    ).lower()
    option = _provider_option(provider)
    same_provider = existing is not None and _clean(existing.provider).lower() == provider

    api_base = payload.get("api_base")
    if api_base is None and same_provider:
        api_base = existing.api_base
    # Blank or whitespace-only input falls back to the provider default.
    api_base = _clean(api_base) or _clean(option["default_api_base"])
    if provider == "openai_compatible" and not api_base:
        raise ValueError("OpenAI Compatible provider requires an API base URL.")

    chat_model = (
        _clean(payload.get("chat_model"))
        or (_clean(existing.chat_model) if same_provider else "")
        or _clean(option["default_chat_model"])
    )
    embedding_model = (
        _clean(payload.get("embedding_model"))
        or (_clean(existing.embedding_model) if same_provider else "")
        or _clean(option["default_embedding_model"])
    )
    if not chat_model:
        raise ValueError("Chat model is required.")
    if not embedding_model:
        raise ValueError("Embedding model is required.")

    # The stored key is kept whenever the payload omits it.
    api_key = payload.get("api_key")
    if api_key is None and existing is not None:
        api_key = existing.api_key
    api_key = _clean(api_key)
    if option["requires_api_key"] and not api_key:
        raise ValueError("API key is required for this provider.")

    name = payload.get("name")
    if name is None and existing is not None:
        name = existing.name
    name = _clean(name) or option["label"]

    return {
        "name": name,
        "provider": provider,
        "api_key": api_key,
        "api_base": api_base or None,
        "chat_model": chat_model,
        "embedding_model": embedding_model,
        "is_active": bool(payload.get("is_active", getattr(existing, "is_active", True))),
    }


class ModelConfigService:
    """CRUD helpers for per-user model configurations.

    The create/update paths below deactivate every other config of the user
    when the saved config is active.
    """

    @staticmethod
    def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
        """Return the user's configs: active ones first, then most recently updated."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id)
            .order_by(UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
            .all()
        )

    @staticmethod
    def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
        """Return config *config_id* if it belongs to *user_id*, else ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
            .first()
        )

    @staticmethod
    def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
        """Return the user's most recently updated active config, or ``None``."""
        return (
            db.query(UserModelConfig)
            .filter(UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
            .order_by(UserModelConfig.updated_at.desc())
            .first()
        )

    @staticmethod
    def require_active_config(db: Session, user_id: int) -> UserModelConfig:
        """Return the user's active config.

        Raises:
            ValueError: if none exists. The message asks the user to add and
                enable a model configuration on the API key page first.
        """
        config = ModelConfigService.get_active_config(db, user_id)
        if config is None:
            raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
        return config

    @staticmethod
    def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
        """Validate *payload*, persist a new config for *user_id*, and return it.

        Raises:
            ValueError: propagated from payload validation.
        """
        data = _normalized_payload(payload.model_dump())
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, user_id)
        item = UserModelConfig(user_id=user_id, **data)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def update_config(
        db: Session,
        item: UserModelConfig,
        payload: ModelConfigUpdate,
    ) -> UserModelConfig:
        """Apply the fields explicitly set in *payload* to *item* and persist it.

        Raises:
            ValueError: propagated from payload validation.
        """
        raw = payload.model_dump(exclude_unset=True)
        # A blank api_key in an update means "keep the stored key", not "clear it".
        if "api_key" in raw and not _clean(raw["api_key"]):
            raw.pop("api_key")
        data = _normalized_payload(raw, existing=item)
        if data["is_active"]:
            ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
        for field, value in data.items():
            setattr(item, field, value)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def delete_config(db: Session, item: UserModelConfig) -> None:
        """Delete *item* and commit."""
        db.delete(item)
        db.commit()

    @staticmethod
    def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
        """Stamp *item* with the current time as a naive UTC datetime and commit."""
        # Same naive-UTC value the deprecated ``datetime.utcnow()`` produced.
        item.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db.add(item)
        db.commit()
        db.refresh(item)
        return item

    @staticmethod
    def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Bulk-set ``is_active = False`` on the user's configs, optionally sparing *exclude_id*.

        Does not commit. With ``synchronize_session=False``, objects already loaded
        in the session are not updated in memory by this statement.
        """
        query = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
        if exclude_id is not None:
            query = query.filter(UserModelConfig.id != exclude_id)
        query.update({UserModelConfig.is_active: False}, synchronize_session=False)