Compare commits


1 commit

Author         SHA1        Message                                 Date
Lance Release  143184c0ae  Bump version: 0.25.2 → 0.25.3-beta.0    2025-10-14 02:25:16 +00:00
8 changed files with 235 additions and 430 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.25.2"
current_version = "0.25.3-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
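The parse pattern is cut off by the diff context here, but for a pre-release bump like 0.25.2 → 0.25.3-beta.0 it has to capture patch and pre-release groups as well. A quick sanity-check sketch, not from the repo: only the major/minor pieces above are visible in the hunk, and everything past the minor group is an assumption about how the pattern continues.

import re

# Assumed continuation of the truncated parse pattern; the patch and
# pre-release groups below are illustrative, not copied from the config.
pattern = re.compile(r"""(?x)
    (?P<major>0|[1-9]\d*)\.
    (?P<minor>0|[1-9]\d*)\.
    (?P<patch>0|[1-9]\d*)
    (?:-(?P<pre_l>[a-zA-Z]+)\.(?P<pre_n>0|[1-9]\d*))?
""")

# The new version string from this commit should parse into its components.
match = pattern.match("0.25.3-beta.0")
assert match is not None
assert match.group("pre_l") == "beta" and match.group("pre_n") == "0"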

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.25.2"
version = "0.25.3-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -80,7 +80,7 @@ embeddings = [
"pillow",
"open-clip-torch",
"cohere",
"colpali-engine>=0.3.12",
"colpali-engine>=0.3.10",
"huggingface_hub",
"InstructorEmbedding",
"google.generativeai",

View File

@@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction
from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings # noqa: F401
from .colpali import ColPaliEmbeddings
from .siglip import SigLipEmbeddings

python/python/lancedb/embeddings/colpali.py (Executable file → Normal file)
View File

@@ -1,347 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Late-interaction embeddings powered by colpali-engine."""
from __future__ import annotations
from functools import lru_cache
from typing import List, Union, Optional, Any
import numpy as np
import io
from typing import Any, Dict, List, Optional, Sequence
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru
from .utils import TEXT, IMAGES, is_flash_attn_2_available
_FAMILY_ALIASES = {
"colsmol": {"colsmol", "colsmolvlm", "smol"},
"colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
"colqwen2": {"colqwen2", "colqwen-2"},
"colpali": {"colpali", "paligemma"},
}
@register("colpali")
class ColPaliEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
_FAMILY_CLASSES = {
"colpali": ("ColPali", "ColPaliProcessor"),
"colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"),
"colqwen2": ("ColQwen2", "ColQwen2Processor"),
"colsmol": ("ColIdefics3", "ColIdefics3Processor"),
}
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
device : str
The device for inference (default "cuda:0").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model. (default None, bitsandbytes needed)
batch_size : int
Batch size for processing inputs (default 2).
"""
def _torch() -> Any:
return attempt_import_or_raise("torch", "torch")
def _torch_dtype(dtype: str) -> Any:
torch = _torch()
mapping = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
if dtype not in mapping:
raise ValueError(
"Unsupported dtype '{}'. Expected one of {}".format(
dtype, ", ".join(sorted(mapping))
)
)
return mapping[dtype]
def _load_pooler(use_pooler: bool) -> Optional[Any]:
if not use_pooler:
return None
token_pooling = attempt_import_or_raise(
"colpali_engine.compression.token_pooling", "colpali-engine"
)
pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None)
if pooler_cls is None:
raise ImportError(
"colpali_engine HierarchicalTokenPooler not available; update colpali-engine"
)
return pooler_cls()
def _move_to_device(batch: Any, device: Any) -> Any:
if device is None:
return batch
torch = _torch()
if isinstance(device, str):
device_obj = torch.device(device)
else:
device_obj = device
if isinstance(batch, dict):
return {k: _move_to_device(v, device_obj) for k, v in batch.items()}
if hasattr(batch, "to"):
return batch.to(device_obj)
return batch
@register("multimodal-late-interaction")
class MultimodalLateInteractionEmbeddings(EmbeddingFunction):
"""Late-interaction embeddings for ViDoRe models."""
model_name: str = "vidore/colSmol-256M"
model_family: Optional[str] = None
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pool_factor: int = 2
batch_size: int = 4
quantization_config: Optional[Any] = None
batch_size: int = 2
def __init__(self, *args: Any, **kwargs: Any) -> None:
_model = None
_processor = None
_token_pooler = None
_vector_dim = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._family = self._resolve_family(self.model_name, self.model_family)
self._vector_dim: Optional[int] = None
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
self.quantization_config,
)
@property
def model(self) -> Any:
"""The cached model."""
return self._get_models()[0]
@property
def processor(self) -> Any:
"""The cached processor."""
return self._get_models()[1]
@property
def pooler(self) -> Optional[Any]:
"""The cached pooler."""
return self._get_models()[2]
@property
def target_device(self) -> Optional[Any]:
"""The cached target device."""
return self._get_models()[3]
# ------------------------------------------------------------------
# Family detection
# ------------------------------------------------------------------
@classmethod
def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str:
if explicit:
family = explicit.lower()
if family not in _FAMILY_CLASSES:
raise ValueError(
"Unknown model_family '{}'. Expected one of {}".format(
explicit, ", ".join(sorted(_FAMILY_CLASSES))
)
)
return family
lowered = model_name.lower()
for family, aliases in _FAMILY_ALIASES.items():
if any(alias in lowered for alias in aliases):
return family
return "colpali"
# ------------------------------------------------------------------
# Model loading
# ------------------------------------------------------------------
@weak_lru(maxsize=1)
def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]:
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine")
@staticmethod
@lru_cache(maxsize=1)
def _load_model(
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
if (
self.quantization_config is not None
and not isinstance(
self.quantization_config, transformers.BitsAndBytesConfig
)
):
raise ValueError(
"quantization_config must be a transformers.BitsAndBytesConfig instance"
)
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
raise ValueError("quantization_config must be a BitsAndBytesConfig")
model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family]
model_cls = getattr(colpali_engine.models, model_cls_name)
processor_cls = getattr(colpali_engine.models, processor_cls_name)
torch = _torch()
device_map = self.device
target_device: Optional[Any] = None
if device_map == "auto":
if torch.cuda.is_available():
device_map = "cuda:0"
target_device = torch.device("cuda:0")
elif (
getattr(torch.backends, "mps", None)
and torch.backends.mps.is_available()
):
device_map = "mps"
target_device = torch.device("mps")
else:
device_map = "cpu"
target_device = torch.device("cpu")
if dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif dtype == "float16":
torch_dtype = torch.float16
elif dtype == "float64":
torch_dtype = torch.float64
else:
try:
target_device = torch.device(device_map)
except (TypeError, ValueError): # pragma: no cover - device map dicts
target_device = None
torch_dtype = _torch_dtype(self.dtype)
if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in {
torch.bfloat16,
torch.float16,
}:
torch_dtype = torch.float32
load_kwargs: Dict[str, Any] = {
"torch_dtype": torch_dtype,
"device_map": device_map,
}
if self.quantization_config is not None:
load_kwargs["quantization_config"] = self.quantization_config
attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None
if attn_impl is not None:
load_kwargs["attn_implementation"] = attn_impl
model = model_cls.from_pretrained(self.model_name, **load_kwargs)
if hasattr(model, "eval"):
model = model.eval()
processor = processor_cls.from_pretrained(self.model_name)
pooler = _load_pooler(self.use_token_pooling)
if target_device is None and hasattr(model, "device"):
target_device = getattr(model, "device")
return model, processor, pooler, target_device
# ------------------------------------------------------------------
# Encoding helpers
# ------------------------------------------------------------------
def _pool_tensor(self, tensor: Any) -> Any:
if self.pooler is None:
return tensor
torch = _torch()
assert isinstance(tensor, torch.Tensor)
expanded = False
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
expanded = True
kwargs = {"pool_factor": self.pool_factor, "padding": True}
tokenizer = getattr(
getattr(self.processor, "tokenizer", None), "padding_side", None
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
if tokenizer is not None:
kwargs["padding_side"] = tokenizer
pooled = self.pooler.pool_embeddings(tensor, **kwargs)
if expanded:
pooled = pooled.squeeze(0)
return pooled
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
return model, processor, token_pooler
def ndims(self):
"""
Return the dimension of a vector in the multivector output (e.g., 128).
"""
torch = attempt_import_or_raise("torch", "torch")
if self._vector_dim is None:
dummy_query = "test"
batch_queries = self._processor.process_queries([dummy_query]).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
self._vector_dim = query_embeddings[0].shape[-1]
return self._vector_dim
def _process_embeddings(self, embeddings):
"""
Format model embeddings into List[List[float]].
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]:
torch = _torch()
if hasattr(embeddings, "last_hidden_state"):
return self._normalize_output(embeddings.last_hidden_state)
if isinstance(embeddings, dict) and "last_hidden_state" in embeddings:
return self._normalize_output(embeddings["last_hidden_state"])
if isinstance(embeddings, torch.Tensor):
pooled = self._pool_tensor(embeddings).detach().cpu()
if pooled.ndim == 2:
pooled = pooled.unsqueeze(0)
target = torch.float64 if self.dtype == "float64" else torch.float32
return pooled.to(target).numpy().tolist()
if isinstance(embeddings, (list, tuple)):
results: List[List[List[float]]] = []
for item in embeddings:
results.extend(self._normalize_output(item))
return results
raise TypeError(f"Unsupported embedding type {type(embeddings)}")
# ------------------------------------------------------------------
# Text encoding
# ------------------------------------------------------------------
def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]:
if not self.processor or not hasattr(self.processor, "process_queries"):
raise RuntimeError("Processor for text queries is not available for this model")
payload = self.processor.process_queries(batch)
payload = _move_to_device(payload, self.target_device)
torch = _torch()
with torch.no_grad():
outputs = self.model(**payload)
return self._normalize_output(outputs)
# ------------------------------------------------------------------
# Image encoding
# ------------------------------------------------------------------
def _prepare_images(self, images: IMAGES) -> List[Any]:
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
prepared: List[Any] = []
for image in self.sanitize_input(images):
if isinstance(image, str) and image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
prepared.append(PIL.Image.open(io.BytesIO(response.content)))
elif isinstance(image, str):
with PIL.Image.open(image) as img:
prepared.append(img.copy())
elif isinstance(image, bytes):
prepared.append(PIL.Image.open(io.BytesIO(image)))
else:
prepared.append(image)
return prepared
def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]:
if not self.processor or not hasattr(self.processor, "process_images"):
raise RuntimeError("Processor for images is not available for this model")
payload = self.processor.process_images(images)
payload = _move_to_device(payload, self.target_device)
torch = _torch()
with torch.no_grad():
outputs = self.model(**payload)
return self._normalize_output(outputs)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]:
results: List[List[List[float]]] = []
for start in range(0, len(values), self.batch_size):
chunk = values[start : start + self.batch_size]
results.extend(encoder(chunk))
return results
tensors = embeddings.detach().cpu()
if tensors.dtype == torch.bfloat16:
tensors = tensors.to(torch.float32)
return (
tensors.numpy()
.astype(np.float64 if self.dtype == "float64" else np.float32)
.tolist()
)
return []
def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
"""
Generate embeddings for text input.
"""
torch = attempt_import_or_raise("torch", "torch")
text = self.sanitize_input(text)
if len(text) == 0:
return []
return self._batched(text, self._encode_text)
all_embeddings = []
for i in range(0, len(text), self.batch_size):
batch_text = text[i : i + self.batch_size]
batch_queries = self._processor.process_queries(batch_text).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
def _prepare_images(self, images: IMAGES) -> List:
"""
Convert image inputs to PIL Images.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
try:
for image in images:
if isinstance(image, str):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)
except Exception as e:
raise ValueError(f"Failed to process image: {e}")
return pil_images
def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
prepared = self._prepare_images(images)
if len(prepared) == 0:
return []
return self._batched(prepared, self._encode_images)
"""
Generate embeddings for a batch of images.
"""
torch = attempt_import_or_raise("torch", "torch")
pil_images = self._prepare_images(images)
all_embeddings = []
for i in range(0, len(pil_images), self.batch_size):
batch_images = pil_images[i : i + self.batch_size]
batch_images = self._processor.process_images(batch_images).to(
self._model.device
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
def compute_query_embeddings(
self, query: str, *args: Any, **kwargs: Any
self, query: Union[str, IMAGES], *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a single user query (text only).
"""
if not isinstance(query, str):
raise ValueError("Late interaction queries must be text")
raise ValueError(
"Query must be a string, image to image search is not supported"
)
return self.generate_text_embeddings([query])
def compute_source_embeddings(
self, images: IMAGES, *args: Any, **kwargs: Any
self, images: IMAGES, *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a batch of source images.
Parameters
----------
images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
Batch of images (paths, URLs, bytes, or PIL Images).
"""
images = self.sanitize_input(images)
return self.generate_image_embeddings(images)
def ndims(self) -> int:
if self._vector_dim is None:
probe = self.generate_text_embeddings(["probe"])
if not probe or not probe[0]:
raise RuntimeError("Failed to determine embedding dimension")
self._vector_dim = len(probe[0][0])
return self._vector_dim
# Backwards compatibility: keep the historical "colpali" key
register("colpali")(MultimodalLateInteractionEmbeddings)
# Legacy class name kept for backwards compatibility in imports
ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
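Taken together with the test changes below, the net effect is that the "colpali" registry key again points at the ColPaliEmbeddings implementation. A minimal usage sketch assembled from the test in this diff, not a definitive recipe: the database path, the single sample row, and reliance on the default checkpoint are illustrative, and a GPU plus the colpali-engine extra are assumed.

import lancedb
import pandas as pd
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector

db = lancedb.connect("/tmp/colpali_demo")  # illustrative path
func = get_registry().get("colpali").create()  # default ColQwen2.5 checkpoint

class MediaItems(LanceModel):
    text: str
    # Images are read from this column and embedded automatically.
    image_uri: str = func.SourceField()
    # Late-interaction models emit one vector per token/patch, hence MultiVector.
    image_vectors: MultiVector(func.ndims()) = func.VectorField()

table = db.create_table("media", schema=MediaItems, mode="overwrite")
table.add(pd.DataFrame({
    "text": ["a cute cat playing with yarn"],
    "image_uri": ["http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"],
}))

# Text-to-image search against the multivector column.
hit = (
    table.search("fluffy companion", vector_column_name="image_vectors")
    .limit(1)
    .to_pydantic(MediaItems)[0]
)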

View File

@@ -17,7 +17,6 @@ from lancedb.embeddings import (
EmbeddingFunctionRegistry,
)
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings
from lancedb.embeddings.registry import get_registry, register
from lancedb.embeddings.utils import retry
from lancedb.pydantic import LanceModel, Vector
@@ -516,16 +515,3 @@ def test_openai_propagates_api_key(monkeypatch):
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
assert len(actual.text) > 0
def test_multimodal_late_interaction_family_detection():
resolver = MultimodalLateInteractionEmbeddings._resolve_family
assert resolver("vidore/colSmol-256M", None) == "colsmol"
assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"
assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2"
assert resolver("vidore/colpali-v1.3", None) == "colpali"
assert resolver("custom/model", None) == "colpali"
assert resolver("any/model", "colqwen2") == "colqwen2"
with pytest.raises(ValueError):
resolver("any/model", "unknown")

View File

@@ -4,7 +4,6 @@
import importlib
import io
import os
import lancedb
import numpy as np
import pandas as pd
@@ -34,24 +33,6 @@ try:
except Exception:
_imagebind = None
try:
import torch
except ImportError:
torch = None
HAS_ACCEL = bool(
torch
and (
torch.cuda.is_available()
or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
)
)
RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
HEAVY_SKIP = pytest.mark.skipif(
not (RUN_HEAVY_VIDORE and HAS_ACCEL),
reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
)
@pytest.mark.slow
@pytest.mark.parametrize(
@@ -616,44 +597,21 @@ def test_voyageai_multimodal_embedding_text_function():
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
@pytest.mark.parametrize(
"model_name",
[
#pytest.param("vidore/colSmol-256M", id="colSmol"),
pytest.param(
"vidore/colqwen2-v1.0",
id="colQwen2",
#marks=HEAVY_SKIP,
),
pytest.param(
"vidore/colqwen2.5-v0.2",
id="colQwen2.5",
# marks=HEAVY_SKIP,
),
pytest.param(
"vidore/colpali-v1.3-merged",
id="colPali",
#marks=HEAVY_SKIP,
),
],
)
def test_multimodal_late_interaction_models(tmp_path, model_name):
def test_colpali(tmp_path):
import requests
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("multimodal-late-interaction").create(
model_name=model_name,
device="auto",
batch_size=1,
)
func = registry.get("colpali").create()
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = func.VectorField()
image_vectors: MultiVector(func.ndims()) = (
func.VectorField()
) # Multivector image embeddings
table = db.create_table("media", schema=MediaItems)
@@ -661,30 +619,41 @@ def test_multimodal_late_interaction_models(tmp_path, model_name):
"a cute cat playing with yarn",
"a puppy in a flower field",
"a red sports car on the highway",
"a vintage bicycle leaning against a wall",
"a plate of delicious pasta",
"fresh fruit salad in a bowl",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# Get images as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
result = (
# Test text-to-image search
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
# Verify multivector dimensions
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1
assert len(first_row["image_vectors"][0]) == func.ndims()
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)
@pytest.mark.slow

View File

@@ -1,58 +0,0 @@
import requests
from lancedb.pydantic import LanceModel, Vector
import importlib
import io
import os
import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector, MultiVector
db = lancedb.connect("~/.db")
registry = get_registry()
func = registry.get("multimodal-late-interaction").create(
model_name="vidore/colQwen2.5-v0.2",
device="auto",
batch_size=1,
)
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = func.VectorField()
table = db.create_table("media", schema=MediaItems, mode="overwrite")
texts = [
"a cute cat playing with yarn",
"a puppy in a flower field",
"a red sports car on the highway",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
]
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
result = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1
assert len(first_row["image_vectors"][0]) == func.ndims()