87 lines
3.5 KiB
Python
87 lines
3.5 KiB
Python
from app.core.config import settings
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain_ollama import OllamaEmbeddings
|
|
from typing import Optional
|
|
# If you plan on adding other embeddings, import them here
|
|
# from some_other_module import AnotherEmbeddingClass
|
|
|
|
|
|
class EmbeddingsFactory:
    """Factory that builds a LangChain embeddings client for a provider name."""

    @staticmethod
    def create(provider: Optional[str] = None, model_profile: Optional[object] = None):
        """Create an embeddings client for the resolved provider.

        Resolution rules:

        * With ``model_profile``: the provider is ``provider``, else the
          profile's ``provider`` attribute, else ``"dashscope"``. The API key
          is the profile's ``api_key`` (empty string when missing/falsy).
          ``api_base`` and ``embedding_model`` come from the profile and fall
          back to the provider defaults read from ``settings``.
        * Without a profile: the provider is ``provider``, else
          ``settings.EMBEDDINGS_PROVIDER``. The key, base URL and model all
          come from the provider defaults read from ``settings``.

        Provider names are matched case-insensitively, ignoring surrounding
        whitespace.

        Args:
            provider: Explicit provider name; takes precedence over the
                profile and over settings.
            model_profile: Optional object that may expose ``provider``,
                ``api_key``, ``api_base`` and ``embedding_model`` attributes;
                missing attributes are treated as unset.

        Returns:
            ``OpenAIEmbeddings`` for ``"openai"``, ``"dashscope"`` and
            ``"openai_compatible"``; ``OllamaEmbeddings`` for ``"ollama"``.

        Raises:
            ValueError: If the resolved provider is not supported.
        """
        if model_profile is not None:
            embeddings_provider = (
                provider or getattr(model_profile, "provider", None) or "dashscope"
            ).strip().lower()
            # NOTE(review): unlike api_base and model below, an empty profile
            # api_key is NOT replaced by the settings default — confirm intended.
            api_key = getattr(model_profile, "api_key", "") or ""
            api_base = (
                getattr(model_profile, "api_base", None)
                or _default_api_base(embeddings_provider)
            )
            model = (
                getattr(model_profile, "embedding_model", None)
                or _default_embedding_model(embeddings_provider)
            )
        else:
            embeddings_provider = (provider or settings.EMBEDDINGS_PROVIDER).strip().lower()
            api_key = _default_api_key(embeddings_provider)
            api_base = _default_api_base(embeddings_provider)
            model = _default_embedding_model(embeddings_provider)

        if embeddings_provider == "openai":
            return OpenAIEmbeddings(
                openai_api_key=api_key,
                openai_api_base=api_base,
                model=model,
            )

        if embeddings_provider in {"dashscope", "openai_compatible"}:
            return OpenAIEmbeddings(
                openai_api_key=api_key,
                openai_api_base=api_base,
                model=model,
                # DashScope OpenAI-compatible embedding expects string input,
                # while LangChain's len-safe path may send token ids.
                check_embedding_ctx_length=False,
                tiktoken_enabled=False,
                skip_empty=True,
                # DashScope embedding API supports at most 10 inputs per batch.
                chunk_size=10,
            )

        if embeddings_provider == "ollama":
            # The resolved api_key is not passed to the Ollama client.
            return OllamaEmbeddings(
                model=model,
                base_url=api_base,
            )

        # Extend with other providers here, e.g.:
        # if embeddings_provider == "another_provider":
        #     return AnotherEmbeddingClass(...)
        raise ValueError(f"Unsupported embeddings provider: {embeddings_provider}")
|
|
|
|
|
|
def _default_embedding_model(provider: Optional[str]) -> str:
    """Return the default embedding model name configured for ``provider``.

    A falsy ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER``; the
    name is compared case-insensitively. DashScope uses
    ``"text-embedding-v4"`` when its setting is empty. Any other
    (unrecognised) provider uses the DashScope model setting, then the
    OpenAI one.
    """
    normalized = (provider or settings.EMBEDDINGS_PROVIDER).lower()

    if normalized == "openai":
        model_name = settings.OPENAI_EMBEDDINGS_MODEL
    elif normalized == "dashscope":
        model_name = settings.DASH_SCOPE_EMBEDDINGS_MODEL or "text-embedding-v4"
    elif normalized == "ollama":
        model_name = settings.OLLAMA_EMBEDDINGS_MODEL
    else:
        # e.g. "openai_compatible": prefer the DashScope model, then OpenAI's.
        model_name = settings.DASH_SCOPE_EMBEDDINGS_MODEL or settings.OPENAI_EMBEDDINGS_MODEL

    return model_name
|
|
|
|
|
|
def _default_api_key(provider: Optional[str]) -> str:
    """Return the API key from ``settings`` for ``provider``.

    A falsy ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER``; the
    name is compared case-insensitively. Only ``"openai"`` and
    ``"dashscope"`` have dedicated keys; every other provider gets
    ``settings.API_KEY``.
    """
    normalized = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    # Map provider -> settings attribute name so only the needed attribute is read.
    setting_name = {
        "openai": "OPENAI_API_KEY",
        "dashscope": "DASH_SCOPE_API_KEY",
    }.get(normalized, "API_KEY")
    return getattr(settings, setting_name)
|
|
|
|
|
|
def _default_api_base(provider: Optional[str]) -> str:
    """Return the API base URL from ``settings`` for ``provider``.

    A falsy ``provider`` falls back to ``settings.EMBEDDINGS_PROVIDER``; the
    name is compared case-insensitively. ``"openai"`` and ``"ollama"`` have
    dedicated base URLs; ``"dashscope"`` and any other provider use
    ``settings.DASH_SCOPE_API_BASE``.
    """
    normalized = (provider or settings.EMBEDDINGS_PROVIDER).lower()
    # Map provider -> settings attribute name so only the needed attribute is read.
    setting_name = {
        "openai": "OPENAI_API_BASE",
        "ollama": "OLLAMA_API_BASE",
    }.get(normalized, "DASH_SCOPE_API_BASE")
    return getattr(settings, setting_name)
|