mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-26 02:20:40 +00:00
Compare commits
1 Commits
v0.27.0-be
...
dantasse/e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65c14f6b40 |
@@ -9,6 +9,8 @@ import numpy as np
|
||||
import io
|
||||
import warnings
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
@@ -26,7 +28,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
colpali_model_name : str
|
||||
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
|
||||
Supports models based on these engines:
|
||||
- ColPali: "vidore/colpali-v1.3" and others
|
||||
@@ -57,7 +59,10 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
useful for large models that do not fit in memory.
|
||||
"""
|
||||
|
||||
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
|
||||
colpali_model_name: str = Field(
|
||||
default="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
|
||||
validation_alias="model_name",
|
||||
)
|
||||
device: str = "auto"
|
||||
dtype: str = "bfloat16"
|
||||
use_token_pooling: bool = True
|
||||
@@ -107,7 +112,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
self._processor,
|
||||
self._token_pooler,
|
||||
) = self._load_model(
|
||||
self.model_name,
|
||||
self.colpali_model_name,
|
||||
dtype,
|
||||
device,
|
||||
self.pooling_strategy,
|
||||
|
||||
@@ -10,7 +10,7 @@ import urllib.parse as urlparse
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from tqdm import tqdm
|
||||
from pydantic import PrivateAttr
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
@@ -24,7 +24,10 @@ if TYPE_CHECKING:
|
||||
|
||||
@register("siglip")
|
||||
class SigLipEmbeddings(EmbeddingFunction):
|
||||
model_name: str = "google/siglip-base-patch16-224"
|
||||
siglip_model_name: str = Field(
|
||||
default="google/siglip-base-patch16-224",
|
||||
validation_alias="model_name",
|
||||
)
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
@@ -39,8 +42,10 @@ class SigLipEmbeddings(EmbeddingFunction):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
self._torch = attempt_import_or_raise("torch")
|
||||
|
||||
self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
|
||||
self._model = transformers.SiglipModel.from_pretrained(self.model_name)
|
||||
self._processor = transformers.AutoProcessor.from_pretrained(
|
||||
self.siglip_model_name
|
||||
)
|
||||
self._model = transformers.SiglipModel.from_pretrained(self.siglip_model_name)
|
||||
self._model.to(self.device)
|
||||
self._model.eval()
|
||||
self._ndims = None
|
||||
|
||||
Reference in New Issue
Block a user