104 lines
3.7 KiB
Python
104 lines
3.7 KiB
Python
from typing import Optional
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_deepseek import ChatDeepSeek
|
|
from langchain_ollama import OllamaLLM
|
|
from app.core.config import settings
|
|
|
|
class LLMFactory:
    """Factory that builds LangChain model clients for the supported providers."""

    @staticmethod
    def create(
        provider: Optional[str] = None,
        temperature: float = 0,
        streaming: bool = True,
        model_profile: Optional[object] = None,
    ) -> BaseChatModel:
        """Build the LLM client for the requested provider.

        Args:
            provider: Provider name, matched case-insensitively. Supported
                values are "openai", "deepseek", "dashscope",
                "openai_compatible" and "ollama".
            temperature: Passed through to the client constructor.
            streaming: Passed through to the client constructor.
            model_profile: Optional object whose ``provider``, ``chat_model``,
                ``api_key`` and ``api_base`` attributes (read with
                ``getattr``) override the values taken from ``settings``.

        Returns:
            The client instance constructed for the resolved provider.

        Raises:
            ValueError: If the resolved provider is not one of the supported
                names.
        """
        if model_profile is None:
            # No profile: everything comes from the application settings.
            provider = provider or settings.CHAT_PROVIDER
            model = _default_chat_model(provider)
            api_key = _default_api_key(provider)
            api_base = _default_api_base(provider)
        else:
            # Precedence: explicit argument, then the profile's provider,
            # then the hard-coded "dashscope" default.
            provider = (
                provider or getattr(model_profile, "provider", None) or "dashscope"
            ).lower()
            model = (
                getattr(model_profile, "chat_model", None)
                or _default_chat_model(provider)
            )
            # Unlike the model and the API base, a missing profile key is not
            # backfilled from settings; it becomes an empty string.
            api_key = getattr(model_profile, "api_key", "") or ""
            api_base = (
                getattr(model_profile, "api_base", None)
                or _default_api_base(provider)
            )

        provider_key = provider.lower()
        shared_kwargs = {
            "temperature": temperature,
            "streaming": streaming,
            "model": model,
        }

        # OpenAI itself and the OpenAI-compatible endpoints share one client.
        if provider_key in {"openai", "dashscope", "openai_compatible"}:
            return ChatOpenAI(
                openai_api_key=api_key,
                openai_api_base=api_base,
                **shared_kwargs,
            )

        if provider_key == "deepseek":
            return ChatDeepSeek(
                api_key=api_key,
                api_base=api_base,
                **shared_kwargs,
            )

        if provider_key == "ollama":
            # NOTE(review): OllamaLLM looks like a completion-style LLM rather
            # than a BaseChatModel, so this branch may not satisfy the declared
            # return type - confirm whether ChatOllama was intended.
            return OllamaLLM(
                base_url=api_base,
                **shared_kwargs,
            )

        # Add more providers here as needed.
        raise ValueError(f"Unsupported LLM provider: {provider}")
|
|
|
|
|
|
def _default_chat_model(provider: Optional[str]) -> str:
    """Return the chat model configured in ``settings`` for ``provider``.

    Args:
        provider: Provider name, matched case-insensitively. When falsy,
            ``settings.CHAT_PROVIDER`` is used instead.

    Returns:
        The provider-specific model setting for "openai", "deepseek",
        "dashscope" or "ollama". For any other provider, returns
        ``settings.DASH_SCOPE_CHAT_MODEL`` if it is truthy, otherwise
        ``settings.OPENAI_MODEL``.
    """
    provider_key = (provider or settings.CHAT_PROVIDER).lower()

    # Setting names rather than values, so only the selected setting is read.
    setting_by_provider = {
        "openai": "OPENAI_MODEL",
        "deepseek": "DEEPSEEK_MODEL",
        "dashscope": "DASH_SCOPE_CHAT_MODEL",
        "ollama": "OLLAMA_MODEL",
    }
    setting_name = setting_by_provider.get(provider_key)
    if setting_name is not None:
        return getattr(settings, setting_name)

    return settings.DASH_SCOPE_CHAT_MODEL or settings.OPENAI_MODEL
|
|
|
|
|
|
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key configured in ``settings`` for ``provider``.

    Args:
        provider: Provider name, matched case-insensitively. When falsy,
            ``settings.CHAT_PROVIDER`` is used instead.

    Returns:
        The provider-specific key setting for "openai", "deepseek" or
        "dashscope"; ``settings.API_KEY`` for any other provider.
    """
    provider_key = (provider or settings.CHAT_PROVIDER).lower()

    # Setting names rather than values, so only the selected setting is read.
    setting_by_provider = {
        "openai": "OPENAI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
        "dashscope": "DASH_SCOPE_API_KEY",
    }
    setting_name = setting_by_provider.get(provider_key, "API_KEY")
    return getattr(settings, setting_name)
|
|
|
|
|
|
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL configured in ``settings`` for ``provider``.

    Args:
        provider: Provider name, matched case-insensitively. When falsy,
            ``settings.CHAT_PROVIDER`` is used instead.

    Returns:
        The provider-specific base setting for "openai", "deepseek",
        "dashscope" or "ollama"; ``settings.DASH_SCOPE_API_BASE`` for any
        other provider.
    """
    provider_key = (provider or settings.CHAT_PROVIDER).lower()

    # Setting names rather than values, so only the selected setting is read.
    setting_by_provider = {
        "openai": "OPENAI_API_BASE",
        "deepseek": "DEEPSEEK_API_BASE",
        "dashscope": "DASH_SCOPE_API_BASE",
        "ollama": "OLLAMA_API_BASE",
    }
    setting_name = setting_by_provider.get(provider_key, "DASH_SCOPE_API_BASE")
    return getattr(settings, setting_name)
|