diff --git a/python/python/lancedb/embeddings/colpali.py b/python/python/lancedb/embeddings/colpali.py
index 52b0d113d..5a0e16c00 100644
--- a/python/python/lancedb/embeddings/colpali.py
+++ b/python/python/lancedb/embeddings/colpali.py
@@ -9,6 +9,8 @@
 import numpy as np
 import io
 import warnings
+from pydantic import Field
+
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
@@ -26,7 +28,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
 
     Parameters
     ----------
-    model_name : str
+    colpali_model_name : str
         The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
         Supports models based on these engines:
         - ColPali: "vidore/colpali-v1.3" and others
@@ -57,7 +59,10 @@ class ColPaliEmbeddings(EmbeddingFunction):
         useful for large models that do not fit in memory.
     """
 
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+    colpali_model_name: str = Field(
+        default="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
+        validation_alias="model_name",
+    )
     device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
@@ -107,7 +112,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
             self._processor,
             self._token_pooler,
         ) = self._load_model(
-            self.model_name,
+            self.colpali_model_name,
             dtype,
             device,
             self.pooling_strategy,
diff --git a/python/python/lancedb/embeddings/siglip.py b/python/python/lancedb/embeddings/siglip.py
index 7e9c6adc9..6ce03c8a0 100644
--- a/python/python/lancedb/embeddings/siglip.py
+++ b/python/python/lancedb/embeddings/siglip.py
@@ -10,7 +10,7 @@
 import urllib.parse as urlparse
 import numpy as np
 import pyarrow as pa
 from tqdm import tqdm
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
@@ -24,7 +24,10 @@ if TYPE_CHECKING:
 
 @register("siglip")
 class SigLipEmbeddings(EmbeddingFunction):
-    model_name: str = "google/siglip-base-patch16-224"
+    siglip_model_name: str = Field(
+        default="google/siglip-base-patch16-224",
+        validation_alias="model_name",
+    )
     device: str = "cpu"
     batch_size: int = 64
     normalize: bool = True
@@ -39,8 +42,10 @@ class SigLipEmbeddings(EmbeddingFunction):
 
         transformers = attempt_import_or_raise("transformers")
         self._torch = attempt_import_or_raise("torch")
-        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
-        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
+        self._processor = transformers.AutoProcessor.from_pretrained(
+            self.siglip_model_name
+        )
+        self._model = transformers.SiglipModel.from_pretrained(self.siglip_model_name)
         self._model.to(self.device)
         self._model.eval()
         self._ndims = None