51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class ProviderConfig:
    """Immutable settings describing how to call one chat-completions provider.

    Field order is part of the positional constructor signature:
    ``(api_key, base_url, max_tokens, model, temperature)``.
    """

    api_key: str
    base_url: str
    max_tokens: int
    model: str
    temperature: float

    @property
    def chat_completions_url(self) -> str:
        """Return the provider's ``/chat/completions`` endpoint URL.

        Trailing slashes are stripped from ``base_url``. The suffix is appended
        only if the trimmed URL does not already end with it, so both
        ``https://host/v1`` and ``https://host/v1/chat/completions/`` resolve
        to ``.../chat/completions``.
        """
        endpoint_suffix = "/chat/completions"
        trimmed = self.base_url.rstrip("/")
        return trimmed if trimmed.endswith(endpoint_suffix) else trimmed + endpoint_suffix
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class ApiSettings:
    """Immutable result of config loading: the chosen provider and its settings."""

    # Name of the selected provider entry.
    provider_name: str
    # Resolved configuration for that provider.
    provider: ProviderConfig
|
||
|
|
|
||
|
|
|
||
|
|
def _optional(entry: dict, key: str, default: object) -> object:
    """Return ``entry[key]``, or *default* when the key is absent or null."""
    value = entry.get(key)
    return default if value is None else value


def _required(entry: dict, key: str) -> object:
    """Return ``entry[key]``; raise ``KeyError(key)`` when it is absent or null."""
    value = entry.get(key)
    if value is None:
        raise KeyError(key)
    return value


def load_api_config(
    path: Path | str = Path("configs/api_config.yaml"),
    provider_name: str | None = None,
) -> ApiSettings:
    """Load the settings of one provider from a YAML API config file.

    The file must hold a mapping with an optional ``default_provider`` name and
    a ``providers`` mapping of provider name -> settings. Each provider entry
    must define ``base_url`` and ``model``. ``api_key``, ``max_tokens`` and
    ``temperature`` are optional and default to ``""``, ``4096`` and ``0.7``.
    A key written with an empty (null) value is treated as absent.

    Args:
        path: Location of the YAML config file.
        provider_name: Provider to select. When falsy, the file's
            ``default_provider`` is used instead.

    Returns:
        ``ApiSettings`` with the selected provider's name and configuration.

    Raises:
        OSError: If the file cannot be read (e.g. ``FileNotFoundError``).
        ValueError: If the document or the ``providers`` section is not a
            mapping, no provider is selected, the selected provider is not
            listed, or its entry is not a mapping.
        KeyError: If the provider entry lacks ``base_url`` or ``model``.
    """
    config_path = Path(path)
    # Used in error messages; equals "api_config.yaml" for the default path.
    file_label = config_path.name
    # safe_load only builds plain YAML types; an empty file yields None -> {}.
    data = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    if not isinstance(data, dict):
        raise ValueError(f"{file_label} must contain a mapping at the top level")

    selected_name = provider_name or data.get("default_provider")
    # A bare `providers:` line parses as None; treat it like an empty mapping.
    providers = data.get("providers") or {}
    if not isinstance(providers, dict):
        raise ValueError(f"{file_label}: 'providers' must be a mapping")

    if not selected_name:
        raise ValueError(f"{file_label} missing default_provider")
    if selected_name not in providers:
        raise ValueError(f"provider not found in {file_label}: {selected_name}")

    provider_data = providers[selected_name]
    if not isinstance(provider_data, dict):
        raise ValueError(f"provider entry in {file_label} must be a mapping: {selected_name}")

    provider = ProviderConfig(
        # _optional keeps a null api_key from becoming the string "None".
        api_key=str(_optional(provider_data, "api_key", "")),
        base_url=str(_required(provider_data, "base_url")),
        max_tokens=int(_optional(provider_data, "max_tokens", 4096)),
        model=str(_required(provider_data, "model")),
        temperature=float(_optional(provider_data, "temperature", 0.7)),
    )
    return ApiSettings(provider_name=selected_name, provider=provider)
|