增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
211
rag-web-ui/backend/app/services/model_config.py
Normal file
211
rag-web-ui/backend/app/services/model_config.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.model_config import UserModelConfig
|
||||
from app.schemas.model_config import ModelConfigCreate, ModelConfigUpdate
|
||||
|
||||
|
||||
PROVIDER_OPTIONS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"provider": "dashscope",
|
||||
"label": "DashScope",
|
||||
"default_api_base": settings.DASH_SCOPE_API_BASE,
|
||||
"default_chat_model": settings.DASH_SCOPE_CHAT_MODEL or "qwen3-max",
|
||||
"default_embedding_model": settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4",
|
||||
"chat_models": ["qwen3-max", "qwen-plus", "qwen-turbo", "qwen-max"],
|
||||
"embedding_models": ["text-embedding-v4", "text-embedding-v3", "text-embedding-v2"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "openai",
|
||||
"label": "OpenAI",
|
||||
"default_api_base": settings.OPENAI_API_BASE,
|
||||
"default_chat_model": settings.OPENAI_MODEL or "gpt-4o",
|
||||
"default_embedding_model": settings.OPENAI_EMBEDDINGS_MODEL or "text-embedding-3-small",
|
||||
"chat_models": ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"],
|
||||
"embedding_models": ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "openai_compatible",
|
||||
"label": "OpenAI Compatible",
|
||||
"default_api_base": "",
|
||||
"default_chat_model": "qwen3-max",
|
||||
"default_embedding_model": "text-embedding-v4",
|
||||
"chat_models": ["qwen3-max", "deepseek-chat", "gpt-4o-mini"],
|
||||
"embedding_models": ["text-embedding-v4", "text-embedding-3-small"],
|
||||
"requires_api_key": True,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
{
|
||||
"provider": "ollama",
|
||||
"label": "Ollama",
|
||||
"default_api_base": settings.OLLAMA_API_BASE,
|
||||
"default_chat_model": settings.OLLAMA_MODEL,
|
||||
"default_embedding_model": settings.OLLAMA_EMBEDDINGS_MODEL,
|
||||
"chat_models": [settings.OLLAMA_MODEL, "llama3.1", "qwen2.5", "deepseek-r1:7b"],
|
||||
"embedding_models": [settings.OLLAMA_EMBEDDINGS_MODEL, "nomic-embed-text", "mxbai-embed-large"],
|
||||
"requires_api_key": False,
|
||||
"supports_custom_api_base": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def provider_options_response() -> Dict[str, Any]:
|
||||
first = PROVIDER_OPTIONS[0]
|
||||
return {
|
||||
"providers": PROVIDER_OPTIONS,
|
||||
"defaults": {
|
||||
"provider": first["provider"],
|
||||
"api_base": first["default_api_base"],
|
||||
"chat_model": first["default_chat_model"],
|
||||
"embedding_model": first["default_embedding_model"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _provider_option(provider: str) -> Dict[str, Any]:
|
||||
normalized = (provider or "dashscope").strip().lower()
|
||||
for option in PROVIDER_OPTIONS:
|
||||
if option["provider"] == normalized:
|
||||
return option
|
||||
raise ValueError(f"Unsupported model provider: {provider}")
|
||||
|
||||
|
||||
def _normalized_payload(payload: Dict[str, Any], existing: Optional[UserModelConfig] = None) -> Dict[str, Any]:
|
||||
provider = str(payload.get("provider") or getattr(existing, "provider", "dashscope")).strip().lower()
|
||||
option = _provider_option(provider)
|
||||
api_base = payload.get("api_base")
|
||||
if api_base is None and existing is not None:
|
||||
api_base = existing.api_base
|
||||
if not api_base:
|
||||
api_base = option["default_api_base"]
|
||||
if option["supports_custom_api_base"] and provider == "openai_compatible" and not api_base:
|
||||
raise ValueError("OpenAI Compatible provider requires an API base URL.")
|
||||
|
||||
chat_model = str(payload.get("chat_model") or getattr(existing, "chat_model", "") or option["default_chat_model"]).strip()
|
||||
embedding_model = str(
|
||||
payload.get("embedding_model")
|
||||
or getattr(existing, "embedding_model", "")
|
||||
or option["default_embedding_model"]
|
||||
).strip()
|
||||
if not chat_model:
|
||||
raise ValueError("Chat model is required.")
|
||||
if not embedding_model:
|
||||
raise ValueError("Embedding model is required.")
|
||||
|
||||
api_key = payload.get("api_key")
|
||||
if api_key is None and existing is not None:
|
||||
api_key = existing.api_key
|
||||
api_key = str(api_key or "").strip()
|
||||
if option["requires_api_key"] and not api_key:
|
||||
raise ValueError("API key is required for this provider.")
|
||||
|
||||
name = payload.get("name")
|
||||
if name is None and existing is not None:
|
||||
name = existing.name
|
||||
name = str(name or "").strip()
|
||||
if not name:
|
||||
name = option["label"]
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": str(api_base).strip() if api_base else None,
|
||||
"chat_model": chat_model,
|
||||
"embedding_model": embedding_model,
|
||||
"is_active": bool(payload.get("is_active", getattr(existing, "is_active", True))),
|
||||
}
|
||||
|
||||
|
||||
class ModelConfigService:
|
||||
@staticmethod
|
||||
def list_configs(db: Session, user_id: int) -> List[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.user_id == user_id)
|
||||
.order_by(UserModelConfig.is_active.desc(), UserModelConfig.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session, user_id: int, config_id: int) -> Optional[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.id == config_id, UserModelConfig.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_active_config(db: Session, user_id: int) -> Optional[UserModelConfig]:
|
||||
return (
|
||||
db.query(UserModelConfig)
|
||||
.filter(UserModelConfig.user_id == user_id, UserModelConfig.is_active.is_(True))
|
||||
.order_by(UserModelConfig.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def require_active_config(db: Session, user_id: int) -> UserModelConfig:
|
||||
config = ModelConfigService.get_active_config(db, user_id)
|
||||
if config is None:
|
||||
raise ValueError("请先在 API 密钥页面新增并启用模型配置。")
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def create_config(db: Session, user_id: int, payload: ModelConfigCreate) -> UserModelConfig:
|
||||
data = _normalized_payload(payload.model_dump())
|
||||
if data["is_active"]:
|
||||
ModelConfigService._deactivate_user_configs(db, user_id)
|
||||
item = UserModelConfig(user_id=user_id, **data)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def update_config(
|
||||
db: Session,
|
||||
item: UserModelConfig,
|
||||
payload: ModelConfigUpdate,
|
||||
) -> UserModelConfig:
|
||||
raw = payload.model_dump(exclude_unset=True)
|
||||
if raw.get("api_key") == "":
|
||||
raw.pop("api_key")
|
||||
data = _normalized_payload(raw, existing=item)
|
||||
if data["is_active"]:
|
||||
ModelConfigService._deactivate_user_configs(db, item.user_id, exclude_id=item.id)
|
||||
for field, value in data.items():
|
||||
setattr(item, field, value)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def delete_config(db: Session, item: UserModelConfig) -> None:
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def touch_last_used(db: Session, item: UserModelConfig) -> UserModelConfig:
|
||||
item.last_used_at = datetime.utcnow()
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
def _deactivate_user_configs(db: Session, user_id: int, exclude_id: Optional[int] = None) -> None:
|
||||
query = db.query(UserModelConfig).filter(UserModelConfig.user_id == user_id)
|
||||
if exclude_id is not None:
|
||||
query = query.filter(UserModelConfig.id != exclude_id)
|
||||
query.update({UserModelConfig.is_active: False}, synchronize_session=False)
|
||||
Reference in New Issue
Block a user