Compare commits

...

1 Commits

Author SHA1 Message Date
Dan Tasse
65c14f6b40 Avoid embedding warnings 2026-01-30 12:35:45 -05:00
2 changed files with 17 additions and 7 deletions

View File

@@ -9,6 +9,8 @@ import numpy as np
 import io
 import warnings
+from pydantic import Field
+
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
@@ -26,7 +28,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
     Parameters
     ----------
-    model_name : str
+    colpali_model_name : str
         The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
         Supports models based on these engines:
         - ColPali: "vidore/colpali-v1.3" and others
@@ -57,7 +59,10 @@ class ColPaliEmbeddings(EmbeddingFunction):
         useful for large models that do not fit in memory.
     """
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+    colpali_model_name: str = Field(
+        default="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
+        validation_alias="model_name",
+    )
     device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
@@ -107,7 +112,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
             self._processor,
             self._token_pooler,
         ) = self._load_model(
-            self.model_name,
+            self.colpali_model_name,
             dtype,
             device,
             self.pooling_strategy,

View File

@@ -10,7 +10,7 @@ import urllib.parse as urlparse
 import numpy as np
 import pyarrow as pa
 from tqdm import tqdm
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
@@ -24,7 +24,10 @@ if TYPE_CHECKING:
 @register("siglip")
 class SigLipEmbeddings(EmbeddingFunction):
-    model_name: str = "google/siglip-base-patch16-224"
+    siglip_model_name: str = Field(
+        default="google/siglip-base-patch16-224",
+        validation_alias="model_name",
+    )
     device: str = "cpu"
     batch_size: int = 64
     normalize: bool = True
@@ -39,8 +42,10 @@ class SigLipEmbeddings(EmbeddingFunction):
         transformers = attempt_import_or_raise("transformers")
         self._torch = attempt_import_or_raise("torch")
-        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
-        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
+        self._processor = transformers.AutoProcessor.from_pretrained(
+            self.siglip_model_name
+        )
+        self._model = transformers.SiglipModel.from_pretrained(self.siglip_model_name)
         self._model.to(self.device)
         self._model.eval()
         self._ndims = None