Compare commits


9 Commits

Author SHA1 Message Date
Colin Patrick McCabe
2f6d525802 fix: support exist_ok in RemoteDBConnection.create_table (#2901)
RemoteDBConnection should support passing exist_ok to create_table, just
like LanceDBConnection (the non-remote form) does. It can support this
by passing 'exist_ok' as the mode parameter.
2026-01-07 12:29:45 -08:00
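For reference, a minimal sketch of what this enables from the Python client (the connection values below are placeholders):

```python
import lancedb

# "db://" URIs produce a RemoteDBConnection; api_key/region are placeholders.
db = lancedb.connect("db://my-project", api_key="sk-...", region="us-east-1")

# With exist_ok=True the call no longer fails if the table already exists;
# per this commit, exist_ok is mapped onto the request's mode parameter.
tbl = db.create_table("my_table", data=[{"id": 1}], exist_ok=True)
```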
Qichao Chu
4494eb9e56 feat: parallelize embedding computations (#2896)
Implement parallel execution of multiple embedding functions using
std::thread::scope to improve performance when a table has multiple
embedding columns.

Key changes:
- Add compute_embeddings_parallel() helper method to WithEmbeddings
- Use fast path for single embeddings (no threading overhead)
- Use scoped threads for parallel execution of multiple embeddings
- Add comprehensive tests including parallelization timing verification
- Update WithEmbeddings documentation

Performance improvements:
- I/O-bound embeddings (OpenAI, Bedrock): High benefit from concurrent
API calls
- CPU-bound embeddings (sentence-transformers): Medium benefit from core
utilization
- Single embedding: No overhead (fast path)

Closes TODO on line 266 in rust/lancedb/src/embeddings.rs
2026-01-06 14:35:56 -08:00
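For intuition, the same fan-out pattern sketched in Python with a thread pool (illustrative only; the actual change uses Rust's std::thread::scope, as shown in the embeddings.rs diff below):

```python
from concurrent.futures import ThreadPoolExecutor

def compute_all_embeddings(batch, embed_fns):
    # Fast path: a single embedding function runs inline, no threading overhead.
    if len(embed_fns) == 1:
        return [embed_fns[0](batch)]
    # Parallel path: I/O-bound functions (hosted embedding APIs) overlap their
    # network waits, which is where most of the speedup comes from.
    with ThreadPoolExecutor(max_workers=len(embed_fns)) as pool:
        futures = [pool.submit(fn, batch) for fn in embed_fns]
        return [f.result() for f in futures]  # results keep definition order
```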
LuQQiu
d67a8743ba feat: support remote ivf rq (#2863) 2026-01-02 15:35:33 -08:00
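Going by the RemoteTable.create_index changes included below, the remote path now accepts "IVF_RQ"; a hedged usage sketch (the table and parameter values are placeholders):

```python
# table is a RemoteTable; IVF_RQ accepts metric, num_partitions, and num_bits
# per the create_index mapping added in this commit.
table.create_index(
    metric="cosine",
    vector_column_name="vector",
    index_type="IVF_RQ",
    num_partitions=256,  # placeholder
    num_bits=8,          # placeholder
)
```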
Chenghao Lyu
46fcbbc1e3 fix(python): require explicit region for S3 buckets with dots (#2892)
When the region is not specified in the s3 path, `resolve_s3_region` from
the "lance-format" project (see [here][1]) will resolve the region by
calling [`resolve_bucket_region`][2], a function from the
"arrow-rs-object-store" project that expects [virtual-hosted-style
URLs][3]. When there are dots (".") in the virtual-hosted-style URL, they
break automatic region detection. See more details in the issue
description:
https://github.com/lancedb/lancedb/issues/1898#issuecomment-3690142427

This PR adds early validation in connect() and connect_async() to raise a
clear error with instructions when the region is not specified for such
buckets.


[1]:
https://github.com/lance-format/lance/blob/v2.0.0-beta.4/rust/lance-io/src/object_store/providers/aws.rs#L197
[2]:
eedbf3d7d8/src/aws/resolve.rs (L52C5-L52C65)
[3]:
https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#virtual-hosted-style-access

Fixes #1898
2026-01-02 15:35:22 -08:00
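In practice, connecting to such a bucket now requires naming the region up front (the bucket name below is hypothetical):

```python
import lancedb

# Bucket names containing dots can't use virtual-hosted-style URLs, so
# automatic region detection fails; pass the region explicitly instead.
db = lancedb.connect(
    "s3://my.dotted.bucket/datasets",         # hypothetical bucket
    storage_options={"region": "us-east-1"},  # "aws_region" also works
)
```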
Prashanth Rao
ff53b76ac0 docs: address styling and aesthetics issues with banner and links (#2878)
Aesthetic and styling fixes to the SDK reference docs:
- [x] Improve readability of LanceDB in the header
- [x] Make header more compact, and consistent in gradient color with
the main website/docs
- [x] Updated favicon to match with the docs page
- [x] Enable permalink display to allow users to get anchor links to
each function/method
- [x] Point readers to the main docs at
[docs.lancedb.com](https://docs.lancedb.com)
2026-01-02 15:15:35 -08:00
fzowl
2adb10e6a8 feat: voyage-multimodal-3.5 (#2887)
voyage-multimodal-3.5 support (text, image and video embeddings)
2026-01-02 15:14:52 -08:00
Colin Patrick McCabe
ac164c352b test: convert test_table_names to test both remote and local (#2888)
Convert test_table_names to test both remote and local connections.

This PR also includes some miscellaneous improvements in
src/test_utils/connection.rs. It starts a thread to drain stdout from
the server process. It adds the
PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT environment variable, which
optionally displays server stdout.

Fix a bash conditional in run_with_test_connection.sh.
2026-01-02 15:08:44 -08:00
Lance Release
8bcac7e372 Bump version: 0.23.1-beta.2 → 0.23.1 2026-01-02 17:39:19 +00:00
Lance Release
e496184ab2 Bump version: 0.23.1-beta.1 → 0.23.1-beta.2 2026-01-02 17:38:54 +00:00
35 changed files with 929 additions and 97 deletions


@@ -1,5 +1,5 @@
[tool.bumpversion]
-current_version = "0.23.1-beta.1"
+current_version = "0.23.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

Cargo.lock generated

@@ -4987,7 +4987,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.23.1-beta.1"
version = "0.23.1"
dependencies = [
"ahash",
"anyhow",
@@ -5066,7 +5066,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.23.1-beta.1"
version = "0.23.1"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -5086,7 +5086,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.26.1-beta.1"
version = "0.26.1"
dependencies = [
"arrow",
"async-trait",


@@ -16,7 +16,7 @@ check_command_exists() {
}
if [[ ! -e ./lancedb ]]; then
-if [[ -v SOPHON_READ_TOKEN ]]; then
+if [[ x${SOPHON_READ_TOKEN} != "x" ]]; then
INPUT="lancedb-linux-x64"
gh release \
--repo lancedb/lancedb \


@@ -11,7 +11,7 @@ watch:
theme:
  name: "material"
  logo: assets/logo.png
-  favicon: assets/logo.png
+  favicon: assets/favicon.ico
  palette:
    # Palette toggle for light mode
    - scheme: lancedb
@@ -32,8 +32,6 @@ theme:
    - content.tooltips
    - toc.follow
    - navigation.top
-    - navigation.tabs
-    - navigation.tabs.sticky
    - navigation.footer
    - navigation.tracking
    - navigation.instant
@@ -115,12 +113,13 @@ markdown_extensions:
      emoji_index: !!python/name:material.extensions.emoji.twemoji
      emoji_generator: !!python/name:material.extensions.emoji.to_svg
  - markdown.extensions.toc:
      baselevel: 1
-      permalink: ""
      toc_depth: 3
+      permalink: true
+      permalink_title: Anchor link to this section
nav:
-  - API reference:
-    - Overview: index.md
+  - Documentation:
+    - SDK Reference: index.md
    - Python: python/python.md
    - Javascript/TypeScript: js/globals.md
    - Java: java/java.md

docs/src/assets/favicon.ico (new binary file, 15 KiB)


@@ -0,0 +1,111 @@
# VoyageAI Embeddings: Multimodal

VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; you can check the list
at [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings).
Supported multimodal models:
- `voyage-multimodal-3` - 1024 dimensions (text + images)
- `voyage-multimodal-3.5` - Flexible dimensions (256, 512, 1024 default, 2048). Supports text, images, and video.
### Video Support (voyage-multimodal-3.5)
The `voyage-multimodal-3.5` model supports video input through:
- Video URLs (`.mp4`, `.webm`, `.mov`, `.avi`, `.mkv`, `.m4v`, `.gif`)
- Video file paths
Constraints: Max 20MB video size.
Supported parameters (to be passed to the `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|-------------------------|-------------------------------------------|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
| `output_dimension` | `int` | `None` | Output dimension for voyage-multimodal-3.5. Valid: 256, 512, 1024, 2048 |
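For example, a minimal sketch of requesting 512-dimensional vectors (assuming `get_registry` from `lancedb.embeddings`, imported in the full example below):

```python
func = get_registry().get("voyageai").create(
    name="voyage-multimodal-3.5", output_dimension=512
)
assert func.ndims() == 512  # flexible dimensions only apply to 3.5
```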
Usage Example:
```python
import base64
import os
from io import BytesIO

import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd

os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'

db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")


def image_to_base64(image_bytes: bytes):
    buffered = BytesIO(image_bytes)
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # image uri as the source
    image_bytes: str = func.SourceField()  # image bytes base64 encoded as the source
    vector: Vector(func.ndims()) = func.VectorField()  # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # Another vector column


if "images" in db.table_names():
    db.drop_table("images")
table = db.create_table("images", schema=Images)

labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]

# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```
Now we can search using text from both the default vector column and the custom vector column:
```python
# text search: default vector column
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

# text search: custom vector column
frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```
Because we're using a multi-modal embedding function, we can also search using images:
```python
from PIL import Image

# image search: default vector column
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```
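Since `voyage-multimodal-3.5` also accepts video, here is a hedged sketch of indexing a video URL through the same flow (the model creation mirrors the example above; the URL is a placeholder):

```python
func35 = get_registry().get("voyageai").create(name="voyage-multimodal-3.5")

class Media(LanceModel):
    label: str
    media_uri: str = func35.SourceField()  # image or video URL
    vector: Vector(func35.ndims()) = func35.VectorField()

videos = db.create_table("videos", schema=Media, mode="overwrite")
# A URL ending in a recognized video extension (.mp4, .webm, ...) is sent to
# the API as a video_url input; videos are limited to 20MB.
videos.add([{"label": "clip", "media_uri": "https://example.com/clip.mp4"}])
```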


@@ -1,8 +1,12 @@
-# API Reference
+# SDK Reference

-This page contains the API reference for the SDKs supported by the LanceDB team.
+This site contains the API reference for the client SDKs supported by [LanceDB](https://lancedb.com).

- [Python](python/python.md)
- [JavaScript/TypeScript](js/globals.md)
- [Java](java/java.md)
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)

+!!! info "LanceDB Documentation"
+    If you're looking for the full documentation of LanceDB, visit [docs.lancedb.com](https://docs.lancedb.com).


@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
-<version>0.23.1-beta.1</version>
+<version>0.23.1</version>
</dependency>
```


@@ -85,17 +85,26 @@
/* Header gradient (only header area) */
.md-header {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
  box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08);
}

+/* Improve brand title contrast on the lavender side */
+.md-header__title,
+.md-header__topic,
+.md-header__title .md-ellipsis,
+.md-header__topic .md-ellipsis {
+  color: #2b1b3a;
+  text-shadow: 0 1px 0 rgba(255, 255, 255, 0.25);
+}

/* Same colors as header for tabs (that hold the text) */
.md-tabs {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
}

/* Dark scheme variant */
[data-md-color-scheme="slate"] .md-header,
[data-md-color-scheme="slate"] .md-tabs {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
}


@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.23.1-beta.1</version>
+<version>0.23.1-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>


@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.23.1-beta.1</version>
+<version>0.23.1-final.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>


@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.23.1-beta.1"
version = "0.23.1"
license.workspace = true
description.workspace = true
repository.workspace = true


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": [
"win32"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",


@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.23.1-beta.1",
"version": "0.23.1",
"cpu": [
"x64",
"arm64"


@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.23.1-beta.1",
"version": "0.23.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",


@@ -13,6 +13,7 @@ __version__ = importlib.metadata.version("lancedb")
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from urllib.parse import urlparse
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .io import StorageOptionsProvider
from .remote import ClientConfig
@@ -28,6 +29,39 @@ from .namespace import (
)
def _check_s3_bucket_with_dots(
    uri: str, storage_options: Optional[Dict[str, str]]
) -> None:
    """
    Check if an S3 URI has a bucket name containing dots and raise an error
    if no region is specified. S3 buckets with dots cannot use
    virtual-hosted-style URLs, which breaks automatic region detection.

    See: https://github.com/lancedb/lancedb/issues/1898
    """
    if not isinstance(uri, str) or not uri.startswith("s3://"):
        return
    parsed = urlparse(uri)
    bucket = parsed.netloc
    if "." not in bucket:
        return
    # Check if region is provided in storage_options
    region_keys = {"region", "aws_region"}
    has_region = storage_options and any(k in storage_options for k in region_keys)
    if not has_region:
        raise ValueError(
            f"S3 bucket name '{bucket}' contains dots, which prevents automatic "
            f"region detection. Please specify the region explicitly via "
            f"storage_options={{'region': '<your-region>'}} or "
            f"storage_options={{'aws_region': '<your-region>'}}. "
            f"See https://github.com/lancedb/lancedb/issues/1898 for details."
        )
def connect(
    uri: URI,
    *,
@@ -121,9 +155,11 @@ def connect(
            storage_options=storage_options,
            **kwargs,
        )

    _check_s3_bucket_with_dots(str(uri), storage_options)

    if kwargs:
        raise ValueError(f"Unknown keyword arguments: {kwargs}")
    return LanceDBConnection(
        uri,
        read_consistency_interval=read_consistency_interval,
@@ -211,6 +247,8 @@ async def connect_async(
    if isinstance(client_config, dict):
        client_config = ClientConfig(**client_config)

    _check_s3_bucket_with_dots(str(uri), storage_options)

    return AsyncConnection(
        await lancedb_connect(
            sanitize_uri(uri),


@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
-from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
+from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional
from pathlib import Path
from urllib.parse import urlparse
@@ -45,11 +45,29 @@ def is_valid_url(text):
        return False


VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}


def is_video_url(url: str) -> bool:
    """Check if URL points to a video file based on extension."""
    parsed = urlparse(url)
    path = parsed.path.lower()
    return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)


def is_video_path(path: Path) -> bool:
    """Check if file path is a video file based on extension."""
    return path.suffix.lower() in VIDEO_EXTENSIONS
def transform_input(input_data: Union[str, bytes, Path]):
    PIL = attempt_import_or_raise("PIL", "pillow")
    if isinstance(input_data, str):
        if is_valid_url(input_data):
-            content = {"type": "image_url", "image_url": input_data}
+            if is_video_url(input_data):
+                content = {"type": "video_url", "video_url": input_data}
+            else:
+                content = {"type": "image_url", "image_url": input_data}
        else:
            content = {"type": "text", "text": input_data}
    elif isinstance(input_data, PIL.Image.Image):
@@ -70,14 +88,24 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, Path):
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
if is_video_path(input_data):
# Read video file and encode as base64
with open(input_data, "rb") as f:
video_bytes = f.read()
video_str = base64.b64encode(video_bytes).decode("utf-8")
content = {
"type": "video_base64",
"video_base64": video_str,
}
else:
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
else:
raise ValueError("Each input should be either str, bytes, Path or Image.")
@@ -91,6 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
    PIL = attempt_import_or_raise("PIL", "pillow")
    if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
        inputs = [inputs]
+    elif isinstance(inputs, list):
+        pass  # Already a list, use as-is
    elif isinstance(inputs, pa.Array):
        inputs = inputs.to_pylist()
    elif isinstance(inputs, pa.ChunkedArray):
@@ -143,11 +173,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
        * voyage-3
        * voyage-3-lite
        * voyage-multimodal-3
        * voyage-multimodal-3.5
        * voyage-finance-2
        * voyage-multilingual-2
        * voyage-law-2
        * voyage-code-2
    output_dimension: int, optional
        The output dimension for models that support flexible dimensions.
        Currently only voyage-multimodal-3.5 supports this feature.
        Valid options: 256, 512, 1024 (default), 2048.

    Examples
    --------
@@ -175,7 +210,10 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"""
    name: str
    output_dimension: Optional[int] = None
    client: ClassVar = None

    _FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
    _VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
    text_embedding_models: list = [
        "voyage-3.5",
        "voyage-3.5-lite",
@@ -186,7 +224,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-law-2",
"voyage-code-2",
]
-    multimodal_embedding_models: list = ["voyage-multimodal-3"]
+    multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"]
    contextual_embedding_models: list = ["voyage-context-3"]

    def _is_multimodal_model(self, model_name: str):
@@ -198,6 +236,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
        return model_name in self.contextual_embedding_models or "context" in model_name

    def ndims(self):
        # Handle flexible dimension models
        if self.name in self._FLEXIBLE_DIM_MODELS:
            if self.output_dimension is not None:
                if self.output_dimension not in self._VALID_DIMENSIONS:
                    raise ValueError(
                        f"Invalid output_dimension {self.output_dimension} "
                        f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
                    )
                return self.output_dimension
            return 1024  # default dimension
        if self.name == "voyage-3-lite":
            return 512
        elif self.name == "voyage-code-2":
@@ -211,12 +260,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
"voyage-multimodal-3",
]:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
def _get_multimodal_kwargs(self, **kwargs):
"""Get kwargs for multimodal embed call, including output_dimension if set."""
if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
kwargs["output_dimension"] = self.output_dimension
return kwargs
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
@@ -234,6 +288,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"""
client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name):
kwargs = self._get_multimodal_kwargs(**kwargs)
result = client.multimodal_embed(
inputs=[[query]], model=self.name, input_type="query", **kwargs
)
@@ -275,6 +330,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
            )
        if has_images:
            # Use non-batched API for images
            kwargs = self._get_multimodal_kwargs(**kwargs)
            result = client.multimodal_embed(
                inputs=sanitized, model=self.name, input_type="document", **kwargs
            )
@@ -357,6 +413,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
            callable: A function that takes a batch of texts and returns embeddings.
        """
        if self._is_multimodal_model(self.name):
            multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)

            def embed_batch(batch: List[str]) -> List[np.array]:
                batch_inputs = sanitize_multimodal_input(batch)
@@ -364,7 +421,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
                    inputs=batch_inputs,
                    model=self.name,
                    input_type=input_type,
-                    **kwargs,
+                    **multimodal_kwargs,
                )
                return result.embeddings


@@ -384,6 +384,7 @@ class RemoteDBConnection(DBConnection):
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        mode: Optional[str] = None,
        exist_ok: bool = False,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
        *,
        namespace: Optional[List[str]] = None,
@@ -412,6 +413,12 @@ class RemoteDBConnection(DBConnection):
            - pyarrow.Schema
            - [LanceModel][lancedb.pydantic.LanceModel]
        mode: str, default "create"
            The mode to use when creating the table.
            Can be either "create", "overwrite", or "exist_ok".
        exist_ok: bool, default False
            If exist_ok is True, and mode is None or "create", mode will be changed
            to "exist_ok".
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contains NaNs.
            One of "error", "drop", "fill".
@@ -483,6 +490,11 @@ class RemoteDBConnection(DBConnection):
        LanceTable(table4)
        """
        if exist_ok:
            if mode == "create":
                mode = "exist_ok"
            elif not mode:
                mode = "exist_ok"
        if namespace is None:
            namespace = []
        validate_table_name(name)


@@ -18,7 +18,17 @@ from lancedb._lancedb import (
    UpdateResult,
)
from lancedb.embeddings.base import EmbeddingFunctionConfig
-from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
+from lancedb.index import (
+    FTS,
+    BTree,
+    Bitmap,
+    HnswSq,
+    IvfFlat,
+    IvfPq,
+    IvfRq,
+    IvfSq,
+    LabelList,
+)
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -265,6 +275,12 @@ class RemoteTable(Table):
                num_sub_vectors=num_sub_vectors,
                num_bits=num_bits,
            )
        elif index_type == "IVF_RQ":
            config = IvfRq(
                distance_type=metric,
                num_partitions=num_partitions,
                num_bits=num_bits,
            )
        elif index_type == "IVF_SQ":
            config = IvfSq(distance_type=metric, num_partitions=num_partitions)
        elif index_type == "IVF_HNSW_PQ":
@@ -279,7 +295,8 @@ class RemoteTable(Table):
        else:
            raise ValueError(
                f"Unknown vector index type: {index_type}. Valid options are"
-                " 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
+                " 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
+                " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
            )
        LOOP.run(

View File

@@ -613,6 +613,133 @@ def test_voyageai_multimodal_embedding_text_function():
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_embedding_function():
    """Test voyage-multimodal-3.5 model with text input."""
    voyageai = (
        get_registry()
        .get("voyageai")
        .create(name="voyage-multimodal-3.5", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
    assert voyageai.ndims() == 1024


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_flexible_dimensions():
    """Test voyage-multimodal-3.5 model with custom output dimension."""
    voyageai = (
        get_registry()
        .get("voyageai")
        .create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    assert voyageai.ndims() == 512
    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == 512


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_image_embedding():
    """Test voyage-multimodal-3.5 model with image input."""
    voyageai = (
        get_registry()
        .get("voyageai")
        .create(name="voyage-multimodal-3.5", max_retries=0)
    )

    class Images(LanceModel):
        label: str
        image_uri: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    db = lancedb.connect("~/lancedb")
    table = db.create_table(
        "test_multimodal_35_images", schema=Images, mode="overwrite"
    )
    labels = ["cat", "dog"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    ]
    table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
    assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
    assert voyageai.ndims() == 1024


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
def test_voyageai_multimodal_35_all_dimensions(dimension):
    """Test voyage-multimodal-3.5 model with all valid output dimensions."""
    voyageai = (
        get_registry()
        .get("voyageai")
        .create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
    )
    assert voyageai.ndims() == dimension

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table(
        f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
    )
    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == dimension


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_invalid_dimension():
    """Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
    with pytest.raises(ValueError, match="Invalid output_dimension"):
        voyageai = (
            get_registry()
            .get("voyageai")
            .create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
        )
        # ndims() is where the validation happens
        voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,


@@ -168,6 +168,42 @@ def test_table_len_sync():
    assert len(table) == 1


def test_create_table_exist_ok():
    def handler(request):
        if request.path == "/v1/table/test/create/?mode=exist_ok":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}], exist_ok=True)
        assert table is not None

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True)
        assert table is not None


def test_create_table_exist_ok_with_mode_overwrite():
    def handler(request):
        if request.path == "/v1/table/test/create/?mode=overwrite":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
        assert table is not None


@pytest.mark.asyncio
async def test_http_error():
    request_id_holder = {"request_id": None}


@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

"""
Tests for S3 bucket names containing dots.

Related issue: https://github.com/lancedb/lancedb/issues/1898

These tests validate the early error checking for S3 bucket names with dots.
No actual S3 connection is made - validation happens before connection.
"""

import pytest

import lancedb

# Test URIs
BUCKET_WITH_DOTS = "s3://my.bucket.name/path"
BUCKET_WITH_DOTS_AND_REGION = ("s3://my.bucket.name", {"region": "us-east-1"})
BUCKET_WITH_DOTS_AND_AWS_REGION = ("s3://my.bucket.name", {"aws_region": "us-east-1"})
BUCKET_WITHOUT_DOTS = "s3://my-bucket/path"


class TestS3BucketWithDotsSync:
    """Tests for connect()."""

    def test_bucket_with_dots_requires_region(self):
        with pytest.raises(ValueError, match="contains dots"):
            lancedb.connect(BUCKET_WITH_DOTS)

    def test_bucket_with_dots_and_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_REGION
        db = lancedb.connect(uri, storage_options=opts)
        assert db is not None

    def test_bucket_with_dots_and_aws_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
        db = lancedb.connect(uri, storage_options=opts)
        assert db is not None

    def test_bucket_without_dots_passes(self):
        db = lancedb.connect(BUCKET_WITHOUT_DOTS)
        assert db is not None


class TestS3BucketWithDotsAsync:
    """Tests for connect_async()."""

    @pytest.mark.asyncio
    async def test_bucket_with_dots_requires_region(self):
        with pytest.raises(ValueError, match="contains dots"):
            await lancedb.connect_async(BUCKET_WITH_DOTS)

    @pytest.mark.asyncio
    async def test_bucket_with_dots_and_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_REGION
        db = await lancedb.connect_async(uri, storage_options=opts)
        assert db is not None

    @pytest.mark.asyncio
    async def test_bucket_with_dots_and_aws_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
        db = await lancedb.connect_async(uri, storage_options=opts)
        assert db is not None

    @pytest.mark.asyncio
    async def test_bucket_without_dots_passes(self):
        db = await lancedb.connect_async(BUCKET_WITHOUT_DOTS)
        assert db is not None


@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.23.1-beta.1"
version = "0.23.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true


@@ -1325,25 +1325,27 @@ mod tests {
    #[tokio::test]
    async fn test_table_names() {
-        let tmp_dir = tempdir().unwrap();
+        let tc = new_test_connection().await.unwrap();
+        let db = tc.connection;
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        let mut names = Vec::with_capacity(100);
        for _ in 0..100 {
-            let mut name = uuid::Uuid::new_v4().to_string();
+            let name = uuid::Uuid::new_v4().to_string();
            names.push(name.clone());
-            name.push_str(".lance");
-            create_dir_all(tmp_dir.path().join(&name)).unwrap();
+            db.create_empty_table(name, schema.clone())
+                .execute()
+                .await
+                .unwrap();
        }
        names.sort();
-        let uri = tmp_dir.path().to_str().unwrap();
-        let db = connect(uri).execute().await.unwrap();
-        let tables = db.table_names().execute().await.unwrap();
+        let tables = db.table_names().limit(100).execute().await.unwrap();
        assert_eq!(tables, names);
        let tables = db
            .table_names()
            .start_after(&names[30])
+            .limit(100)
            .execute()
            .await
            .unwrap();


@@ -120,8 +120,13 @@ impl MemoryRegistry {
}
/// A record batch reader that has embeddings applied to it
-/// This is a wrapper around another record batch reader that applies an embedding function
-/// when reading from the record batch
+///
+/// This is a wrapper around another record batch reader that applies embedding functions
+/// when reading from the record batch.
+///
+/// When multiple embedding functions are defined, they are computed in parallel using
+/// scoped threads to improve performance. For a single embedding function, computation
+/// is done inline without threading overhead.
pub struct WithEmbeddings<R: RecordBatchReader> {
    inner: R,
    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
@@ -235,6 +240,48 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
            column_definitions,
        })
    }

    fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
        if self.embeddings.len() == 1 {
            let (fld, func) = &self.embeddings[0];
            let src_column =
                batch
                    .column_by_name(&fld.source_column)
                    .ok_or_else(|| Error::InvalidInput {
                        message: format!("Source column '{}' not found", fld.source_column),
                    })?;
            return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
        }

        // Parallel path: multiple embeddings
        std::thread::scope(|s| {
            let handles: Vec<_> = self
                .embeddings
                .iter()
                .map(|(fld, func)| {
                    let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
                        Error::InvalidInput {
                            message: format!("Source column '{}' not found", fld.source_column),
                        }
                    })?;
                    let handle =
                        s.spawn(move || func.compute_source_embeddings(src_column.clone()));
                    Ok(handle)
                })
                .collect::<Result<_>>()?;

            handles
                .into_iter()
                .map(|h| {
                    h.join().map_err(|e| Error::Runtime {
                        message: format!("Thread panicked during embedding computation: {:?}", e),
                    })?
                })
                .collect()
        })
    }
}
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
@@ -262,19 +309,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
    fn next(&mut self) -> Option<Self::Item> {
        let batch = self.inner.next()?;
        match batch {
-            Ok(mut batch) => {
-                // todo: parallelize this
-                for (fld, func) in self.embeddings.iter() {
-                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
-                    let embedding = match func.compute_source_embeddings(src_column.clone()) {
-                        Ok(embedding) => embedding,
-                        Err(e) => {
-                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
-                                "Error computing embedding: {}",
-                                e
-                            ))))
-                        }
-                    };
+            Ok(batch) => {
+                let embeddings = match self.compute_embeddings_parallel(&batch) {
+                    Ok(emb) => emb,
+                    Err(e) => {
+                        return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
+                            "Error computing embedding: {}",
+                            e
+                        ))))
+                    }
+                };
+                let mut batch = batch;
+                for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
                    let dst_field_name = fld
                        .dest_column
                        .clone()
@@ -286,7 +333,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
                        embedding.nulls().is_some(),
                    );
-                    match batch.try_with_column(dst_field.clone(), embedding) {
+                    match batch.try_with_column(dst_field.clone(), embedding.clone()) {
                        Ok(b) => batch = b,
                        Err(e) => return Some(Err(e)),
                    };


@@ -1088,6 +1088,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfRq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::BTree(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
}


@@ -5,16 +5,19 @@
use regex::Regex;
use std::env;
-use std::io::{BufRead, BufReader};
-use std::process::{Child, ChildStdout, Command, Stdio};
+use std::process::Stdio;
+use tokio::io::{AsyncBufReadExt, BufReader};
+use tokio::process::{Child, ChildStdout, Command};
+use tokio::sync::mpsc;

use crate::{connect, Connection};
-use anyhow::{bail, Result};
+use anyhow::{anyhow, bail, Result};
use tempfile::{tempdir, TempDir};

pub struct TestConnection {
    pub uri: String,
    pub connection: Connection,
+    pub is_remote: bool,
    _temp_dir: Option<TempDir>,
    _process: Option<TestProcess>,
}
@@ -37,6 +40,56 @@ pub async fn new_test_connection() -> Result<TestConnection> {
    }
}

async fn spawn_stdout_reader(
    mut stdout: BufReader<ChildStdout>,
    port_sender: mpsc::Sender<anyhow::Result<String>>,
) -> tokio::task::JoinHandle<()> {
    let print_stdout = env::var("PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT").is_ok();
    tokio::spawn(async move {
        let mut line = String::new();
        let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
        loop {
            line.clear();
            let result = stdout.read_line(&mut line).await;
            if let Err(err) = result {
                port_sender
                    .send(Err(anyhow!(
                        "error while reading from process output: {}",
                        err
                    )))
                    .await
                    .unwrap();
                return;
            } else if result.unwrap() == 0 {
                port_sender
                    .send(Err(anyhow!(
                        " hit EOF before reading port from process output."
                    )))
                    .await
                    .unwrap();
                return;
            }
            if re.is_match(&line) {
                let caps = re.captures(&line).unwrap();
                port_sender.send(Ok(caps[1].to_string())).await.unwrap();
                break;
            }
        }
        loop {
            line.clear();
            match stdout.read_line(&mut line).await {
                Err(_) => return,
                Ok(0) => return,
                Ok(_size) => {
                    if print_stdout {
                        print!("{}", line);
                    }
                }
            }
        }
    })
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
    let temp_dir = tempdir()?;
    let data_path = temp_dir.path().to_str().unwrap().to_string();
@@ -57,38 +110,25 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
        child: child_result.unwrap(),
    };
    let stdout = BufReader::new(process.child.stdout.take().unwrap());
-    let port = read_process_port(stdout)?;
+    let (port_sender, mut port_receiver) = mpsc::channel(5);
+    let _reader = spawn_stdout_reader(stdout, port_sender).await;
+    let port = match port_receiver.recv().await {
+        None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
+        Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
+        Some(Ok(port)) => port,
+    };
    let uri = "db://test";
    let host_override = format!("http://localhost:{}", port);
    let connection = create_new_connection(uri, &host_override).await?;
    Ok(TestConnection {
        uri: uri.to_string(),
        connection,
+        is_remote: true,
        _temp_dir: Some(temp_dir),
        _process: Some(process),
    })
}

-fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
-    let mut line = String::new();
-    let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
-    loop {
-        let result = stdout.read_line(&mut line);
-        if let Err(err) = result {
-            bail!(format!(
-                "read_process_port: error while reading from process output: {}",
-                err
-            ));
-        } else if result.unwrap() == 0 {
-            bail!("read_process_port: hit EOF before reading port from process output.");
-        }
-        if re.is_match(&line) {
-            let caps = re.captures(&line).unwrap();
-            return Ok(caps[1].to_string());
-        }
-    }
-}
#[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
connect(uri)
@@ -114,6 +154,7 @@ async fn new_local_connection() -> Result<TestConnection> {
    Ok(TestConnection {
        uri: uri.to_string(),
        connection,
+        is_remote: false,
        _temp_dir: Some(temp_dir),
        _process: None,
    })


@@ -0,0 +1,253 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::{
    borrow::Cow,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    time::Duration,
};

use arrow::buffer::NullBuffer;
use arrow_array::{
    Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{
    embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
    Error, Result,
};

#[derive(Debug)]
struct SlowMockEmbed {
    name: String,
    dim: usize,
    delay_ms: u64,
    call_count: Arc<AtomicUsize>,
}

impl SlowMockEmbed {
    pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
        Self {
            name,
            dim,
            delay_ms,
            call_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    pub fn get_call_count(&self) -> usize {
        self.call_count.load(Ordering::SeqCst)
    }
}

impl EmbeddingFunction for SlowMockEmbed {
    fn name(&self) -> &str {
        &self.name
    }

    fn source_type(&self) -> Result<Cow<'_, DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }

    fn dest_type(&self) -> Result<Cow<'_, DataType>> {
        Ok(Cow::Owned(DataType::new_fixed_size_list(
            DataType::Float32,
            self.dim as _,
            true,
        )))
    }

    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // Simulate slow embedding computation
        std::thread::sleep(Duration::from_millis(self.delay_ms));
        self.call_count.fetch_add(1, Ordering::SeqCst);
        let len = source.len();
        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
        let field = Field::new("item", inner.data_type().clone(), false);
        let arr = FixedSizeListArray::new(
            Arc::new(field),
            self.dim as _,
            inner,
            Some(NullBuffer::new_valid(len)),
        );
        Ok(Arc::new(arr))
    }

    fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        unimplemented!()
    }
}

fn create_test_batch() -> Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let text = StringArray::from(vec!["hello", "world"]);
    RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
        message: format!("Failed to create test batch: {}", e),
    })
}

#[test]
fn test_single_embedding_fast_path() {
    // Single embedding should execute without spawning threads
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
    let embedding_def = EmbeddingDefinition::new("text", "test", Some("embedding"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![(embedding_def, embed.clone() as Arc<dyn EmbeddingFunction>)];

    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
    let result = with_embeddings.next().unwrap().unwrap();

    assert!(result.column_by_name("embedding").is_some());
    assert_eq!(embed.get_call_count(), 1);
}

#[test]
fn test_multiple_embeddings_parallel() {
    // Multiple embeddings should execute in parallel
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 100));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 100));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 100));
    let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
        (def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
        (def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
    ];

    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
    let result = with_embeddings.next().unwrap().unwrap();

    // Verify all embedding columns are present
    assert!(result.column_by_name("emb1").is_some());
    assert!(result.column_by_name("emb2").is_some());
    assert!(result.column_by_name("emb3").is_some());

    // Verify all embeddings were computed
    assert_eq!(embed1.get_call_count(), 1);
    assert_eq!(embed2.get_call_count(), 1);
    assert_eq!(embed3.get_call_count(), 1);
}

#[test]
fn test_embedding_column_order_preserved() {
    // Verify that embedding columns are added in the same order as definitions
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 10));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 10));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 10));
    let def1 = EmbeddingDefinition::new("text", "embed1", Some("first"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("second"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("third"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1 as Arc<dyn EmbeddingFunction>),
        (def2, embed2 as Arc<dyn EmbeddingFunction>),
        (def3, embed3 as Arc<dyn EmbeddingFunction>),
    ];

    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
    let result = with_embeddings.next().unwrap().unwrap();
    let result_schema = result.schema();

    // Original column is first
    assert_eq!(result_schema.field(0).name(), "text");
    // Embedding columns follow in order
    assert_eq!(result_schema.field(1).name(), "first");
    assert_eq!(result_schema.field(2).name(), "second");
    assert_eq!(result_schema.field(3).name(), "third");
}

#[test]
fn test_embedding_error_propagation() {
    // Test that errors from embedding computation are properly propagated
    #[derive(Debug)]
    struct FailingEmbed {
        name: String,
    }

    impl EmbeddingFunction for FailingEmbed {
        fn name(&self) -> &str {
            &self.name
        }

        fn source_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::Utf8))
        }

        fn dest_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::new_fixed_size_list(
                DataType::Float32,
                2,
                true,
            )))
        }

        fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            Err(Error::Runtime {
                message: "Intentional failure".to_string(),
            })
        }

        fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            unimplemented!()
        }
    }

    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed = Arc::new(FailingEmbed {
        name: "failing".to_string(),
    });
    let def = EmbeddingDefinition::new("text", "failing", Some("emb"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![(def, embed as Arc<dyn EmbeddingFunction>)];

    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
    let result = with_embeddings.next().unwrap();

    assert!(result.is_err());
    let err_msg = format!("{}", result.err().unwrap());
    assert!(err_msg.contains("Intentional failure"));
}

#[test]
fn test_maybe_embedded_with_no_embeddings() {
    // Test that MaybeEmbedded::No variant works correctly
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
    let table_def = lancedb::table::TableDefinition {
        schema: schema.clone(),
        column_definitions: vec![lancedb::table::ColumnDefinition {
            kind: lancedb::table::ColumnKind::Physical,
        }],
    };

    let mut maybe_embedded = MaybeEmbedded::try_new(reader, table_def, None).unwrap();
    let result = maybe_embedded.next().unwrap().unwrap();

    assert_eq!(result.num_columns(), 1);
    assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref());
}