Compare commits

..

3 Commits

Author SHA1 Message Date
Lei Xu
b57faf9835 rename gha task 2026-01-29 16:26:28 -08:00
Lei Xu
4800f31479 fixlint 2026-01-29 16:21:23 -08:00
Lei Xu
ec464ad01e remove pydantic 1 support 2026-01-29 16:18:56 -08:00
22 changed files with 91 additions and 670 deletions

View File

@@ -25,7 +25,7 @@ jobs:
lint:
name: "Lint"
timeout-minutes: 30
runs-on: "ubuntu-22.04"
runs-on: "ubuntu-24.04"
defaults:
run:
shell: bash
@@ -195,7 +195,7 @@ jobs:
# Make sure wheels are not included in the Rust cache
- name: Delete wheels
run: rm -rf target/wheels
pydantic1x:
min-deps:
timeout-minutes: 30
runs-on: "ubuntu-24.04"
defaults:
@@ -217,7 +217,6 @@ jobs:
python-version: "3.10"
- name: Install lancedb
run: |
pip install "pydantic<2"
pip install pyarrow==16
pip install --extra-index-url https://pypi.fury.io/lance-format/ --extra-index-url https://pypi.fury.io/lancedb/ -e .[tests]
pip install tantivy

View File

@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
| REST API | https://docs.lancedb.com/api-reference/rest |
| REST API | https://docs.lancedb.com/api-reference/introduction |
## **Join Us and Contribute**

View File

@@ -1,62 +0,0 @@
# VoyageAI Embeddings
Voyage AI provides cutting-edge embedding and rerankers.
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
Supported models are:
**Voyage-4 Series (Latest)**
- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)
**Voyage-3 Series**
- voyage-3
- voyage-3-lite
**Domain-Specific Models**
- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|--------|---------|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>1.0.4</lance-core.version>
<lance-core.version>1.0.0-rc.2</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.28.0-beta.0"
current_version = "0.27.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.28.0-beta.0"
version = "0.27.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -8,7 +8,7 @@ dependencies = [
"overrides>=0.7; python_version<'3.12'",
"packaging",
"pyarrow>=16",
"pydantic>=1.10",
"pydantic>=2",
"tqdm>=4.27.0",
"lance-namespace>=0.3.2"
]

View File

@@ -22,12 +22,7 @@ class BackgroundEventLoop:
self.thread.start()
def run(self, future):
concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
try:
return concurrent_future.result()
except BaseException:
concurrent_future.cancel()
raise
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
LOOP = BackgroundEventLoop()

View File

@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
"""
Convert image inputs to PIL Images.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL_Image.open(image) as im:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL_Image.open(io.BytesIO(image)))
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)

View File

@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
if isinstance(inputs, list):
inputs = inputs
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
inputs = [inputs]
return inputs
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if parsed.scheme == "file":
pil_image = PIL_Image.open(parsed.path)
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
buffered = io.BytesIO()
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL_Image.Image):
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL_Image.Image):
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(

View File

@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image))
if isinstance(image, PIL_Image.Image):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL_Image.open(parsed.path)
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL_Image.open(image if os.name == "nt" else parsed.path)
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")

View File

@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("SigLIP supports str or PIL Image as query")
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
return image_features.cpu().detach().numpy().squeeze()
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
return image.convert("RGB") if image.mode != "RGB" else image
elif isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
if parsed.scheme == "file":
return PIL_Image.open(parsed.path).convert("RGB")
return PIL.Image.open(parsed.path).convert("RGB")
elif parsed.scheme == "":
path = image if os.name == "nt" else parsed.path
return PIL_Image.open(path).convert("RGB")
return PIL.Image.open(path).convert("RGB")
elif parsed.scheme.startswith("http"):
image_bytes = url_retrieve(image)
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
else:
raise NotImplementedError("Only local and http(s) urls are supported")
else:

View File

@@ -21,9 +21,6 @@ if TYPE_CHECKING:
# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-4": 320_000,
"voyage-4-lite": 1_000_000,
"voyage-4-large": 120_000,
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
@@ -64,7 +61,7 @@ def is_video_path(path: Path) -> bool:
def transform_input(input_data: Union[str, bytes, Path]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
if is_video_url(input_data):
@@ -73,7 +70,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL_Image.Image):
elif isinstance(input_data, PIL.Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -82,7 +79,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL_Image.open(BytesIO(input_data))
img = PIL.Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -101,7 +98,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"video_base64": video_str,
}
else:
img = PIL_Image.open(input_data)
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -119,8 +116,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs]
elif isinstance(inputs, list):
pass # Already a list, use as-is
@@ -133,7 +130,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
@@ -170,9 +167,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
name: str
The name of the model to use. List of acceptable models:
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
* voyage-4-lite (1024 dims, optimized for latency and cost)
* voyage-4-large (1024 dims, best retrieval quality)
* voyage-context-3
* voyage-3.5
* voyage-3.5-lite
@@ -221,9 +215,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-3.5",
"voyage-3.5-lite",
"voyage-3",
@@ -261,9 +252,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-context-3",
"voyage-3.5",
"voyage-3.5-lite",

View File

@@ -6,7 +6,6 @@
from __future__ import annotations
import inspect
import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
@@ -141,14 +140,6 @@ def Vector(
raise TypeError("A list of numbers or numpy.ndarray is needed")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {"type": "number"}
field_schema["maxItems"] = dim
field_schema["minItems"] = dim
return FixedSizeList
@@ -226,26 +217,14 @@ def MultiVector(
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, (list, range)):
raise TypeError("A list of vectors is needed")
for vec in v:
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
raise TypeError(f"Each vector must be a list of {dim} numbers")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
return MultiVectorList
@@ -281,20 +260,10 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
)
if PYDANTIC_VERSION.major < 2:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
]
else:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field)
for name, field in model.model_fields.items()
]
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.model_fields.items()
]
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
@@ -334,7 +303,7 @@ def _unwrap_optional_annotation(annotation: Any) -> Any | None:
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
elif isinstance(annotation, types.UnionType):
args = annotation.__args__
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
@@ -367,7 +336,7 @@ def is_nullable(field: FieldInfo) -> bool:
if origin == Union:
if any(typ is type(None) for typ in args):
return True
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
elif isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
for typ in args:
if typ is type(None):
@@ -474,8 +443,6 @@ class LanceModel(pydantic.BaseModel):
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
@classmethod
@@ -518,18 +485,8 @@ def get_extras(field_info: FieldInfo, key: str) -> Any:
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
if PYDANTIC_VERSION.major < 2:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.dict()
else:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()

View File

@@ -517,36 +517,19 @@ def test_ollama_embedding(tmp_path):
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize(
"model_name,expected_dims",
[
("voyage-3", 1024),
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
],
)
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
"""Integration test for VoyageAI text embedding models with real API calls."""
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == expected_dims, (
f"{model_name} should have {expected_dims} dimensions"
)
# Test search functionality
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
@pytest.mark.slow

View File

@@ -438,15 +438,11 @@ def test_filter_with_splits(mem_db):
row_count = permutation_tbl.count_rows()
assert row_count == 67
# Verify the permutation table only contains row_id and split_id
assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ("A", "B") for cat in categories)
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):

View File

@@ -8,7 +8,7 @@ import http.server
import json
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import uuid
from packaging.version import Version
@@ -1203,22 +1203,3 @@ async def test_header_provider_overrides_static_headers():
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
) as db:
await db.table_names()
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
from lancedb.background_loop import BackgroundEventLoop
mock_future = MagicMock()
mock_future.result.side_effect = exception()
with (
patch.object(BackgroundEventLoop, "__init__", return_value=None),
patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
):
loop = BackgroundEventLoop()
loop.loop = MagicMock()
with pytest.raises(exception):
loop.run(None)
mock_future.cancel.assert_called_once()

View File

@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Unit tests for VoyageAI embedding function.
These tests verify model registration and configuration without requiring API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from lancedb.embeddings import get_registry
@pytest.fixture(autouse=True)
def reset_voyageai_client():
"""Reset VoyageAI client before and after each test to avoid state pollution."""
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
VoyageAIEmbeddingFunction.client = None
yield
VoyageAIEmbeddingFunction.client = None
class TestVoyageAIModelRegistration:
"""Tests for VoyageAI model registration and configuration."""
@pytest.fixture
def mock_voyageai_client(self):
"""Mock VoyageAI client to avoid API calls."""
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
mock_client = MagicMock()
mock_voyageai = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock.return_value = mock_voyageai
yield mock_client
def test_voyageai_registered(self):
"""Test that VoyageAI is registered in the embedding function registry."""
registry = get_registry()
assert registry.get("voyageai") is not None
@pytest.mark.parametrize(
"model_name,expected_dims",
[
# Voyage-4 series (all 1024 dims)
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
# Voyage-3 series
("voyage-3", 1024),
("voyage-3-lite", 512),
# Domain-specific models
("voyage-finance-2", 1024),
("voyage-multilingual-2", 1024),
("voyage-law-2", 1024),
("voyage-code-2", 1536),
# Multimodal
("voyage-multimodal-3", 1024),
],
)
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
"""Test that each model returns the correct dimensions."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert func.ndims() == expected_dims, (
f"Model {model_name} should have {expected_dims} dimensions"
)
def test_unsupported_model_raises_error(self, mock_voyageai_client):
"""Test that unsupported models raise ValueError."""
registry = get_registry()
func = registry.get("voyageai").create(name="unsupported-model")
with pytest.raises(ValueError, match="not supported"):
func.ndims()
@pytest.mark.parametrize(
"model_name",
[
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
],
)
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
"""Test that voyage-4 models are classified as text models (not multimodal)."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert not func._is_multimodal_model(model_name), (
f"{model_name} should be a text model, not multimodal"
)
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
"""Test that voyage-4 models are in the text_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" in func.text_embedding_models
assert "voyage-4-lite" in func.text_embedding_models
assert "voyage-4-large" in func.text_embedding_models
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" not in func.multimodal_embedding_models
assert "voyage-4-lite" not in func.multimodal_embedding_models
assert "voyage-4-large" not in func.multimodal_embedding_models

View File

@@ -251,36 +251,8 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let embedding_registry = self.embedding_registry.clone();
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
fn into_request(self) -> Result<CreateTableRequest> {
if self.embeddings.is_empty() {
return Ok(self.request);
}
let CreateTableData::Empty(table_def) = self.request.data else {
unreachable!("CreateTableBuilder<false> should always have Empty data")
};
let schema = table_def.schema.clone();
let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
let table_definition = with_embeddings.table_definition()?;
Ok(CreateTableRequest {
data: CreateTableData::Empty(table_definition),
..self.request
})
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
@@ -1720,128 +1692,4 @@ mod tests {
let cloned_count = cloned_table.count_rows(None).await.unwrap();
assert_eq!(source_count, cloned_count);
}
#[tokio::test]
async fn test_create_empty_table_with_embeddings() {
use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use std::borrow::Cow;
#[derive(Debug, Clone)]
struct MockEmbedding {
dim: usize,
}
impl EmbeddingFunction for MockEmbedding {
fn name(&self) -> &str {
"test_embedding"
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as i32,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let len = source.len();
let values = vec![1.0f32; len * self.dim];
let values = Arc::new(Float32Array::from(values));
let field = Arc::new(Field::new("item", DataType::Float32, true));
Ok(Arc::new(FixedSizeListArray::new(
field,
self.dim as i32,
values,
None,
)))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let embed_func = Arc::new(MockEmbedding { dim: 128 });
db.embedding_registry()
.register("test_embedding", embed_func.clone())
.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let ed = EmbeddingDefinition {
source_column: "name".to_owned(),
dest_column: Some("name_embedding".to_owned()),
embedding_name: "test_embedding".to_owned(),
};
let table = db
.create_empty_table("test", schema)
.mode(CreateTableMode::Overwrite)
.add_embedding(ed)
.unwrap()
.execute()
.await
.unwrap();
let table_schema = table.schema().await.unwrap();
assert!(table_schema.column_with_name("name").is_some());
assert!(table_schema.column_with_name("name_embedding").is_some());
let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
assert_eq!(
embedding_field.data_type(),
&DataType::new_fixed_size_list(DataType::Float32, 128, true)
);
let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let input_batch = RecordBatch::try_new(
input_schema.clone(),
vec![Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Charlie"),
]))],
)
.unwrap();
let input_reader = Box::new(RecordBatchIterator::new(
vec![Ok(input_batch)].into_iter(),
input_schema,
));
table.add(input_reader).execute().await.unwrap();
let results = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert!(batch.column_by_name("name_embedding").is_some());
let embedding_col = batch
.column_by_name("name_embedding")
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
assert_eq!(embedding_col.len(), 3);
}
}

View File

@@ -12,8 +12,6 @@ use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance_arrow::SchemaExt;
use lance_core::ROW_ID;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
@@ -362,15 +360,11 @@ impl Splitter {
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
(SPLIT_ID_COLUMN.to_string(), calculation.clone()),
(ROW_ID.to_string(), ROW_ID.to_string()),
])),
SplitStrategy::Hash { columns, .. } => {
let mut cols = columns.clone();
cols.push(ROW_ID.to_string());
query.select(Select::Columns(cols))
}
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}

View File

@@ -79,11 +79,10 @@ use self::merge::MergeInsertBuilder;
pub mod datafusion;
pub(crate) mod dataset;
pub mod delete;
pub mod merge;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use delete::DeleteResult;
use futures::future::{join_all, Either};
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
@@ -447,6 +446,15 @@ pub struct AddResult {
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct MergeResult {
// The commit version associated with the operation.
@@ -3070,8 +3078,11 @@ impl BaseTable for NativeTable {
/// Delete rows from the table
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
// Delegate to the submodule implementation
delete::execute_delete(self, predicate).await
let mut dataset = self.dataset.get_mut().await?;
dataset.delete(predicate).await?;
Ok(DeleteResult {
version: dataset.version().version,
})
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {

View File

@@ -1,161 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the delete logic
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
// We access the dataset from the table. Since this is in the same module hierarchy (super),
// and 'dataset' is pub(crate), we can access it.
let mut dataset = table.dataset.get_mut().await?;
// Perform the actual delete on the Lance dataset
dataset.delete(predicate).await?;
// Return the result with the new version
Ok(DeleteResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::query::ExecutableQuery;
use futures::TryStreamExt;
#[tokio::test]
async fn test_delete_simple() {
let conn = connect("memory://").execute().await.unwrap();
// 1. Create a table with values 0 to 9
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap();
let table = conn
.create_table(
"test_delete",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// 2. Verify initial state
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// 3. Execute Delete (removes values > 5)
table.delete("i > 5").await.unwrap();
// 4. Verify results
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
// 5. Verify specific data consistency
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
// Ensure no value > 5 exists
for val in array.iter() {
assert!(val.unwrap() <= 5);
}
}
#[tokio::test]
async fn rows_removed_schema_same() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3, 4, 5]),
("name", Utf8, ["a", "b", "c", "d", "e"])
)
.unwrap();
let original_schema = batch.schema();
let table = conn
.create_table(
"test_delete_all",
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
)
.execute()
.await
.unwrap();
table.delete("true").await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 0);
let current_schema = table.schema().await.unwrap();
//check if the original schema is the same as current
assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();
// Create a table with 5 rows
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_delete_noop",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Capture the initial state (Rows = 5, Version = 1)
let initial_rows = table.count_rows(None).await.unwrap();
let initial_version = table.version().await.unwrap();
assert_eq!(initial_rows, 5);
table.delete("false").await.unwrap();
// Rows should still be 5
let current_rows = table.count_rows(None).await.unwrap();
assert_eq!(
current_rows, initial_rows,
"Data should not change when predicate is false"
);
// version check
let current_version = table.version().await.unwrap();
assert!(
current_version > initial_version,
"Table version must increment after delete operation"
);
}
}