Compare commits

...

1 Commit

Author      SHA1         Message                    Date
Dan Tasse   65c14f6b40   Avoid embedding warnings   2026-01-30 12:35:45 -05:00
2 changed files with 17 additions and 7 deletions
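
The rename-plus-validation_alias pattern in this commit looks aimed at pydantic v2's protected "model_" namespace warning (an inference from the diff, not stated in the commit message): any field whose name starts with model_ emits a UserWarning when the class is defined. A minimal, self-contained sketch of the before/after pattern using plain pydantic, not the LanceDB classes:

from pydantic import BaseModel, Field


class Before(BaseModel):
    # Defining this class emits a UserWarning: the field name clashes with
    # pydantic's protected "model_" namespace (model_dump, model_validate, ...).
    model_name: str = "google/siglip-base-patch16-224"


class After(BaseModel):
    # Renamed field; validation_alias keeps "model_name" usable as an input key,
    # so callers that pass model_name= still validate without the warning.
    siglip_model_name: str = Field(
        default="google/siglip-base-patch16-224",
        validation_alias="model_name",
    )


print(After(model_name="custom/model").siglip_model_name)  # -> "custom/model"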

View File

@@ -9,6 +9,8 @@ import numpy as np
 import io
 import warnings
+from pydantic import Field
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
@@ -26,7 +28,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
     Parameters
     ----------
-    model_name : str
+    colpali_model_name : str
         The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
         Supports models based on these engines:
        - ColPali: "vidore/colpali-v1.3" and others
@@ -57,7 +59,10 @@ class ColPaliEmbeddings(EmbeddingFunction):
         useful for large models that do not fit in memory.
     """
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+    colpali_model_name: str = Field(
+        default="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
+        validation_alias="model_name",
+    )
     device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
@@ -107,7 +112,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
             self._processor,
             self._token_pooler,
         ) = self._load_model(
-            self.model_name,
+            self.colpali_model_name,
             dtype,
             device,
             self.pooling_strategy,
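
One behavioral detail of validation_alias worth noting, shown below with a generic pydantic model (not asserting what EmbeddingFunction's model_config actually sets): under pydantic's default settings the alias becomes the accepted input key, and the new field name itself is not accepted at construction time unless populate_by_name is enabled.

from pydantic import BaseModel, Field


class Example(BaseModel):
    colpali_model_name: str = Field(
        default="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
        validation_alias="model_name",
    )


# The alias is the accepted input key:
print(Example(model_name="vidore/colpali-v1.3").colpali_model_name)
# -> "vidore/colpali-v1.3"

# With default model_config (populate_by_name=False, extra="ignore"), the field
# name itself is silently ignored as an input key and the default is kept:
print(Example(colpali_model_name="vidore/colpali-v1.3").colpali_model_name)
# -> "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"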

View File

@@ -10,7 +10,7 @@ import urllib.parse as urlparse
 import numpy as np
 import pyarrow as pa
 from tqdm import tqdm
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
@@ -24,7 +24,10 @@ if TYPE_CHECKING:
 @register("siglip")
 class SigLipEmbeddings(EmbeddingFunction):
-    model_name: str = "google/siglip-base-patch16-224"
+    siglip_model_name: str = Field(
+        default="google/siglip-base-patch16-224",
+        validation_alias="model_name",
+    )
     device: str = "cpu"
     batch_size: int = 64
     normalize: bool = True
@@ -39,8 +42,10 @@ class SigLipEmbeddings(EmbeddingFunction):
         transformers = attempt_import_or_raise("transformers")
         self._torch = attempt_import_or_raise("torch")
-        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
-        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
+        self._processor = transformers.AutoProcessor.from_pretrained(
+            self.siglip_model_name
+        )
+        self._model = transformers.SiglipModel.from_pretrained(self.siglip_model_name)
         self._model.to(self.device)
         self._model.eval()
         self._ndims = None
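
For comparison, pydantic also allows silencing this warning without renaming anything, by clearing the protected namespaces on the model config. A hedged sketch of that alternative, again using plain pydantic rather than the LanceDB code:

from pydantic import BaseModel, ConfigDict


class KeepName(BaseModel):
    # Emptying protected_namespaces suppresses the "model_" warning entirely,
    # at the cost of losing the guard against clashes with future model_* methods.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = "google/siglip-base-patch16-224"


Renaming the fields and keeping model_name only as a validation alias, as this commit does, leaves pydantic's namespace check in place while preserving the external parameter name.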