Compare commits

..

2 Commits

Author SHA1 Message Date
BubbleCal
929c683a6e Handle version number refs 2025-12-22 16:53:16 +08:00
lancedb automation
e2794d1a29 chore: update lance dependency to v2.0.0-beta.3 2025-12-19 22:04:10 +00:00
57 changed files with 775 additions and 1723 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.23.1" current_version = "0.23.1-beta.1"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -167,13 +167,13 @@ jobs:
- name: Build - name: Build
run: | run: |
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build --profile ci --features aws,remote --tests --locked --target ${{ matrix.target }} cargo build --profile ci --features remote --tests --locked --target ${{ matrix.target }}
- name: Run tests - name: Run tests
# Can only run tests when target matches host # Can only run tests when target matches host
if: ${{ matrix.target == 'x86_64-pc-windows-msvc' }} if: ${{ matrix.target == 'x86_64-pc-windows-msvc' }}
run: | run: |
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo test --profile ci --features aws,remote --locked cargo test --profile ci --features remote --locked
msrv: msrv:
# Check the minimum supported Rust version # Check the minimum supported Rust version

847
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,39 +15,39 @@ categories = ["database-implementations"]
rust-version = "1.78.0" rust-version = "1.78.0"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-core = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-datagen = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-file = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-io = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-index = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-linalg = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-namespace = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-namespace-impls = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-table = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-testing = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-datafusion = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-encoding = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" } lance-arrow = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8" ahash = "0.8"
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false } arrow = { version = "56.2", optional = false }
arrow-array = "57.2" arrow-array = "56.2"
arrow-data = "57.2" arrow-data = "56.2"
arrow-ipc = "57.2" arrow-ipc = "56.2"
arrow-ord = "57.2" arrow-ord = "56.2"
arrow-schema = "57.2" arrow-schema = "56.2"
arrow-select = "57.2" arrow-select = "56.2"
arrow-cast = "57.2" arrow-cast = "56.2"
async-trait = "0" async-trait = "0"
datafusion = { version = "51.0", default-features = false } datafusion = { version = "50.1", default-features = false }
datafusion-catalog = "51.0" datafusion-catalog = "50.1"
datafusion-common = { version = "51.0", default-features = false } datafusion-common = { version = "50.1", default-features = false }
datafusion-execution = "51.0" datafusion-execution = "50.1"
datafusion-expr = "51.0" datafusion-expr = "50.1"
datafusion-physical-plan = "51.0" datafusion-physical-plan = "50.1"
env_logger = "0.11" env_logger = "0.11"
half = { "version" = "2.7.1", default-features = false, features = [ half = { "version" = "2.6.0", default-features = false, features = [
"num-traits", "num-traits",
] } ] }
futures = "0" futures = "0"
@@ -59,7 +59,7 @@ rand = "0.9"
snafu = "0.8" snafu = "0.8"
url = "2" url = "2"
num-traits = "0.2" num-traits = "0.2"
regex = "1.12" regex = "1.10"
lazy_static = "1" lazy_static = "1"
semver = "1.0.25" semver = "1.0.25"
chrono = "0.4" chrono = "0.4"

View File

@@ -16,7 +16,7 @@ check_command_exists() {
} }
if [[ ! -e ./lancedb ]]; then if [[ ! -e ./lancedb ]]; then
if [[ x${SOPHON_READ_TOKEN} != "x" ]]; then if [[ -v SOPHON_READ_TOKEN ]]; then
INPUT="lancedb-linux-x64" INPUT="lancedb-linux-x64"
gh release \ gh release \
--repo lancedb/lancedb \ --repo lancedb/lancedb \

View File

@@ -11,7 +11,7 @@ watch:
theme: theme:
name: "material" name: "material"
logo: assets/logo.png logo: assets/logo.png
favicon: assets/favicon.ico favicon: assets/logo.png
palette: palette:
# Palette toggle for light mode # Palette toggle for light mode
- scheme: lancedb - scheme: lancedb
@@ -32,6 +32,8 @@ theme:
- content.tooltips - content.tooltips
- toc.follow - toc.follow
- navigation.top - navigation.top
- navigation.tabs
- navigation.tabs.sticky
- navigation.footer - navigation.footer
- navigation.tracking - navigation.tracking
- navigation.instant - navigation.instant
@@ -113,13 +115,12 @@ markdown_extensions:
emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg emoji_generator: !!python/name:material.extensions.emoji.to_svg
- markdown.extensions.toc: - markdown.extensions.toc:
toc_depth: 3 baselevel: 1
permalink: true permalink: ""
permalink_title: Anchor link to this section
nav: nav:
- Documentation: - API reference:
- SDK Reference: index.md - Overview: index.md
- Python: python/python.md - Python: python/python.md
- Javascript/TypeScript: js/globals.md - Javascript/TypeScript: js/globals.md
- Java: java/java.md - Java: java/java.md

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

View File

@@ -1,111 +0,0 @@
# VoyageAI Embeddings : Multimodal
VoyageAI embeddings can also be used to embed both text and image data, only some of the models support image data and you can check the list
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings)
Supported multimodal models:
- `voyage-multimodal-3` - 1024 dimensions (text + images)
- `voyage-multimodal-3.5` - Flexible dimensions (256, 512, 1024 default, 2048). Supports text, images, and video.
### Video Support (voyage-multimodal-3.5)
The `voyage-multimodal-3.5` model supports video input through:
- Video URLs (`.mp4`, `.webm`, `.mov`, `.avi`, `.mkv`, `.m4v`, `.gif`)
- Video file paths
Constraints: Max 20MB video size.
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|-------------------------|-------------------------------------------|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
| `output_dimension` | `int` | `None` | Output dimension for voyage-multimodal-3.5. Valid: 256, 512, 1024, 2048 |
Usage Example:
```python
import base64
import os
from io import BytesIO
import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'
db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")
def image_to_base64(image_bytes: bytes):
buffered = BytesIO(image_bytes)
img_str = base64.b64encode(buffered.getvalue())
return img_str.decode("utf-8")
class Images(LanceModel):
label: str
image_uri: str = func.SourceField() # image uri as the source
image_bytes: str = func.SourceField() # image bytes base64 encoded as the source
vector: Vector(func.ndims()) = func.VectorField() # vector column
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
if "images" in db.table_names():
db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```
Now we can search using text from both the default vector column and the custom vector column
```python
# text search
actual = table.search("man's best friend", "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label) # prints "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(frombytes.label)
```
Because we're using a multi-modal embedding function, we can also search using images
```python
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image, "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")
# image search using a custom vector column
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(actual.label)
```

View File

@@ -1,12 +1,8 @@
# SDK Reference # API Reference
This site contains the API reference for the client SDKs supported by [LanceDB](https://lancedb.com). This page contains the API reference for the SDKs supported by the LanceDB team.
- [Python](python/python.md) - [Python](python/python.md)
- [JavaScript/TypeScript](js/globals.md) - [JavaScript/TypeScript](js/globals.md)
- [Java](java/java.md) - [Java](java/java.md)
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html) - [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
!!! info "LanceDB Documentation"
If you're looking for the full documentation of LanceDB, visit [docs.lancedb.com](https://docs.lancedb.com).

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency> <dependency>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId> <artifactId>lancedb-core</artifactId>
<version>0.23.1</version> <version>0.23.1-beta.1</version>
</dependency> </dependency>
``` ```

View File

@@ -85,26 +85,17 @@
/* Header gradient (only header area) */ /* Header gradient (only header area) */
.md-header { .md-header {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%); background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08); box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08);
} }
/* Improve brand title contrast on the lavender side */
.md-header__title,
.md-header__topic,
.md-header__title .md-ellipsis,
.md-header__topic .md-ellipsis {
color: #2b1b3a;
text-shadow: 0 1px 0 rgba(255, 255, 255, 0.25);
}
/* Same colors as header for tabs (that hold the text) */ /* Same colors as header for tabs (that hold the text) */
.md-tabs { .md-tabs {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%); background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
} }
/* Dark scheme variant */ /* Dark scheme variant */
[data-md-color-scheme="slate"] .md-header, [data-md-color-scheme="slate"] .md-header,
[data-md-color-scheme="slate"] .md-tabs { [data-md-color-scheme="slate"] .md-tabs {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%); background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
} }

View File

@@ -8,7 +8,7 @@
<parent> <parent>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.23.1-final.0</version> <version>0.23.1-beta.1</version>
<relativePath>../pom.xml</relativePath> <relativePath>../pom.xml</relativePath>
</parent> </parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.23.1-final.0</version> <version>0.23.1-beta.1</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>${project.artifactId}</name> <name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description> <description>LanceDB Java SDK Parent POM</description>

View File

@@ -1,7 +1,7 @@
[package] [package]
name = "lancedb-nodejs" name = "lancedb-nodejs"
edition.workspace = true edition.workspace = true
version = "0.23.1" version = "0.23.1-beta.1"
license.workspace = true license.workspace = true
description.workspace = true description.workspace = true
repository.workspace = true repository.workspace = true
@@ -36,6 +36,6 @@ aws-lc-rs = "=1.13.0"
napi-build = "2.1" napi-build = "2.1"
[features] [features]
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"] default = ["remote", "lancedb/default"]
fp16kernels = ["lancedb/fp16kernels"] fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"] remote = ["lancedb/remote"]

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-x64", "name": "@lancedb/lancedb-darwin-x64",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.darwin-x64.node", "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-musl", "name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node", "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-musl", "name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node", "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-arm64-msvc", "name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": [ "os": [
"win32" "win32"
], ],

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.23.1", "version": "0.23.1-beta.1",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.23.1", "version": "0.23.1-beta.1",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.23.1", "version": "0.23.1-beta.1",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"

View File

@@ -11,7 +11,7 @@
"ann" "ann"
], ],
"private": false, "private": false,
"version": "0.23.1", "version": "0.23.1-beta.1",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.27.0-beta.0" current_version = "0.26.1-beta.1"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.27.0-beta.0" version = "0.26.1-beta.1"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true
@@ -14,15 +14,15 @@ name = "_lancedb"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
arrow = { version = "57.2", features = ["pyarrow"] } arrow = { version = "56.2", features = ["pyarrow"] }
async-trait = "0.1" async-trait = "0.1"
lancedb = { path = "../rust/lancedb", default-features = false } lancedb = { path = "../rust/lancedb", default-features = false }
lance-core.workspace = true lance-core.workspace = true
lance-namespace.workspace = true lance-namespace.workspace = true
lance-io.workspace = true lance-io.workspace = true
env_logger.workspace = true env_logger.workspace = true
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] } pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.26", features = [ pyo3-async-runtimes = { version = "0.25", features = [
"attributes", "attributes",
"tokio-runtime", "tokio-runtime",
] } ] }
@@ -32,12 +32,12 @@ snafu.workspace = true
tokio = { version = "1.40", features = ["sync"] } tokio = { version = "1.40", features = ["sync"] }
[build-dependencies] [build-dependencies]
pyo3-build-config = { version = "0.26", features = [ pyo3-build-config = { version = "0.25", features = [
"extension-module", "extension-module",
"abi3-py39", "abi3-py39",
] } ] }
[features] [features]
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"] default = ["remote", "lancedb/default"]
fp16kernels = ["lancedb/fp16kernels"] fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"] remote = ["lancedb/remote"]

View File

@@ -13,7 +13,6 @@ __version__ = importlib.metadata.version("lancedb")
from ._lancedb import connect as lancedb_connect from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri from .common import URI, sanitize_uri
from urllib.parse import urlparse
from .db import AsyncConnection, DBConnection, LanceDBConnection from .db import AsyncConnection, DBConnection, LanceDBConnection
from .io import StorageOptionsProvider from .io import StorageOptionsProvider
from .remote import ClientConfig from .remote import ClientConfig
@@ -29,39 +28,6 @@ from .namespace import (
) )
def _check_s3_bucket_with_dots(
uri: str, storage_options: Optional[Dict[str, str]]
) -> None:
"""
Check if an S3 URI has a bucket name containing dots and warn if no region
is specified. S3 buckets with dots cannot use virtual-hosted-style URLs,
which breaks automatic region detection.
See: https://github.com/lancedb/lancedb/issues/1898
"""
if not isinstance(uri, str) or not uri.startswith("s3://"):
return
parsed = urlparse(uri)
bucket = parsed.netloc
if "." not in bucket:
return
# Check if region is provided in storage_options
region_keys = {"region", "aws_region"}
has_region = storage_options and any(k in storage_options for k in region_keys)
if not has_region:
raise ValueError(
f"S3 bucket name '{bucket}' contains dots, which prevents automatic "
f"region detection. Please specify the region explicitly via "
f"storage_options={{'region': '<your-region>'}} or "
f"storage_options={{'aws_region': '<your-region>'}}. "
f"See https://github.com/lancedb/lancedb/issues/1898 for details."
)
def connect( def connect(
uri: URI, uri: URI,
*, *,
@@ -155,11 +121,9 @@ def connect(
storage_options=storage_options, storage_options=storage_options,
**kwargs, **kwargs,
) )
_check_s3_bucket_with_dots(str(uri), storage_options)
if kwargs: if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}") raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection( return LanceDBConnection(
uri, uri,
read_consistency_interval=read_consistency_interval, read_consistency_interval=read_consistency_interval,
@@ -247,8 +211,6 @@ async def connect_async(
if isinstance(client_config, dict): if isinstance(client_config, dict):
client_config = ClientConfig(**client_config) client_config = ClientConfig(**client_config)
_check_s3_bucket_with_dots(str(uri), storage_options)
return AsyncConnection( return AsyncConnection(
await lancedb_connect( await lancedb_connect(
sanitize_uri(uri), sanitize_uri(uri),

View File

@@ -179,7 +179,6 @@ class Table:
cleanup_since_ms: Optional[int] = None, cleanup_since_ms: Optional[int] = None,
delete_unverified: Optional[bool] = None, delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ... ) -> OptimizeStats: ...
async def uri(self) -> str: ...
@property @property
def tags(self) -> Tags: ... def tags(self) -> Tags: ...
def query(self) -> Query: ... def query(self) -> Query: ...

View File

@@ -210,8 +210,10 @@ class DBConnection(EnforceOverrides):
page_token: str, optional page_token: str, optional
The token to use for pagination. If not present, start from the beginning. The token to use for pagination. If not present, start from the beginning.
Typically, this token is last table name from the previous page. Typically, this token is last table name from the previous page.
Only supported by LanceDb Cloud.
limit: int, default 10 limit: int, default 10
The size of the page to return. The size of the page to return.
Only supported by LanceDb Cloud.
Returns Returns
------- -------

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64 import base64
import os import os
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -45,29 +45,11 @@ def is_valid_url(text):
return False return False
VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}
def is_video_url(url: str) -> bool:
"""Check if URL points to a video file based on extension."""
parsed = urlparse(url)
path = parsed.path.lower()
return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)
def is_video_path(path: Path) -> bool:
"""Check if file path is a video file based on extension."""
return path.suffix.lower() in VIDEO_EXTENSIONS
def transform_input(input_data: Union[str, bytes, Path]): def transform_input(input_data: Union[str, bytes, Path]):
PIL = attempt_import_or_raise("PIL", "pillow") PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str): if isinstance(input_data, str):
if is_valid_url(input_data): if is_valid_url(input_data):
if is_video_url(input_data): content = {"type": "image_url", "image_url": input_data}
content = {"type": "video_url", "video_url": input_data}
else:
content = {"type": "image_url", "image_url": input_data}
else: else:
content = {"type": "text", "text": input_data} content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL.Image.Image): elif isinstance(input_data, PIL.Image.Image):
@@ -88,24 +70,14 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str, "image_base64": "data:image/jpeg;base64," + img_str,
} }
elif isinstance(input_data, Path): elif isinstance(input_data, Path):
if is_video_path(input_data): img = PIL.Image.open(input_data)
# Read video file and encode as base64 buffered = BytesIO()
with open(input_data, "rb") as f: img.save(buffered, format="JPEG")
video_bytes = f.read() img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
video_str = base64.b64encode(video_bytes).decode("utf-8") content = {
content = { "type": "image_base64",
"type": "video_base64", "image_base64": "data:image/jpeg;base64," + img_str,
"video_base64": video_str, }
}
else:
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
else: else:
raise ValueError("Each input should be either str, bytes, Path or Image.") raise ValueError("Each input should be either str, bytes, Path or Image.")
@@ -119,8 +91,6 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
PIL = attempt_import_or_raise("PIL", "pillow") PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)): if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs] inputs = [inputs]
elif isinstance(inputs, list):
pass # Already a list, use as-is
elif isinstance(inputs, pa.Array): elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist() inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray): elif isinstance(inputs, pa.ChunkedArray):
@@ -173,16 +143,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
* voyage-3 * voyage-3
* voyage-3-lite * voyage-3-lite
* voyage-multimodal-3 * voyage-multimodal-3
* voyage-multimodal-3.5
* voyage-finance-2 * voyage-finance-2
* voyage-multilingual-2 * voyage-multilingual-2
* voyage-law-2 * voyage-law-2
* voyage-code-2 * voyage-code-2
output_dimension: int, optional
The output dimension for models that support flexible dimensions.
Currently only voyage-multimodal-3.5 supports this feature.
Valid options: 256, 512, 1024 (default), 2048.
Examples Examples
-------- --------
@@ -210,10 +175,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
""" """
name: str name: str
output_dimension: Optional[int] = None
client: ClassVar = None client: ClassVar = None
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [ text_embedding_models: list = [
"voyage-3.5", "voyage-3.5",
"voyage-3.5-lite", "voyage-3.5-lite",
@@ -224,7 +186,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-law-2", "voyage-law-2",
"voyage-code-2", "voyage-code-2",
] ]
multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"] multimodal_embedding_models: list = ["voyage-multimodal-3"]
contextual_embedding_models: list = ["voyage-context-3"] contextual_embedding_models: list = ["voyage-context-3"]
def _is_multimodal_model(self, model_name: str): def _is_multimodal_model(self, model_name: str):
@@ -236,17 +198,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
return model_name in self.contextual_embedding_models or "context" in model_name return model_name in self.contextual_embedding_models or "context" in model_name
def ndims(self): def ndims(self):
# Handle flexible dimension models
if self.name in self._FLEXIBLE_DIM_MODELS:
if self.output_dimension is not None:
if self.output_dimension not in self._VALID_DIMENSIONS:
raise ValueError(
f"Invalid output_dimension {self.output_dimension} "
f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
)
return self.output_dimension
return 1024 # default dimension
if self.name == "voyage-3-lite": if self.name == "voyage-3-lite":
return 512 return 512
elif self.name == "voyage-code-2": elif self.name == "voyage-code-2":
@@ -260,17 +211,12 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-finance-2", "voyage-finance-2",
"voyage-multilingual-2", "voyage-multilingual-2",
"voyage-law-2", "voyage-law-2",
"voyage-multimodal-3",
]: ]:
return 1024 return 1024
else: else:
raise ValueError(f"Model {self.name} not supported") raise ValueError(f"Model {self.name} not supported")
def _get_multimodal_kwargs(self, **kwargs):
"""Get kwargs for multimodal embed call, including output_dimension if set."""
if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
kwargs["output_dimension"] = self.output_dimension
return kwargs
def compute_query_embeddings( def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]: ) -> List[np.ndarray]:
@@ -288,7 +234,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
""" """
client = VoyageAIEmbeddingFunction._get_client() client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name): if self._is_multimodal_model(self.name):
kwargs = self._get_multimodal_kwargs(**kwargs)
result = client.multimodal_embed( result = client.multimodal_embed(
inputs=[[query]], model=self.name, input_type="query", **kwargs inputs=[[query]], model=self.name, input_type="query", **kwargs
) )
@@ -330,7 +275,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
) )
if has_images: if has_images:
# Use non-batched API for images # Use non-batched API for images
kwargs = self._get_multimodal_kwargs(**kwargs)
result = client.multimodal_embed( result = client.multimodal_embed(
inputs=sanitized, model=self.name, input_type="document", **kwargs inputs=sanitized, model=self.name, input_type="document", **kwargs
) )
@@ -413,7 +357,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
callable: A function that takes a batch of texts and returns embeddings. callable: A function that takes a batch of texts and returns embeddings.
""" """
if self._is_multimodal_model(self.name): if self._is_multimodal_model(self.name):
multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)
def embed_batch(batch: List[str]) -> List[np.array]: def embed_batch(batch: List[str]) -> List[np.array]:
batch_inputs = sanitize_multimodal_input(batch) batch_inputs = sanitize_multimodal_input(batch)
@@ -421,7 +364,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
inputs=batch_inputs, inputs=batch_inputs,
model=self.name, model=self.name,
input_type=input_type, input_type=input_type,
**multimodal_kwargs, **kwargs,
) )
return result.embeddings return result.embeddings

View File

@@ -961,27 +961,22 @@ class LanceQueryBuilder(ABC):
>>> query = [100, 100] >>> query = [100, 100]
>>> plan = table.search(query).analyze_plan() >>> plan = table.search(query).analyze_plan()
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
AnalyzeExec verbose=true, elapsed=..., metrics=... AnalyzeExec verbose=true, metrics=[], cumulative_cpu=...
TracedExec, elapsed=..., metrics=... TracedExec, metrics=[], cumulative_cpu=...
ProjectionExec: elapsed=..., expr=[...], ProjectionExec: expr=[...], metrics=[...], cumulative_cpu=...
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...] GlobalLimitExec: skip=0, fetch=10, metrics=[...], cumulative_cpu=...
GlobalLimitExec: elapsed=..., skip=0, fetch=10, FilterExec: _distance@2 IS NOT NULL,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...] metrics=[output_rows=..., elapsed_compute=...], cumulative_cpu=...
FilterExec: elapsed=..., _distance@2 IS NOT NULL, metrics=[...] SortExec: TopK(fetch=10), expr=[...],
SortExec: elapsed=..., TopK(fetch=10), expr=[...],
preserve_partitioning=[...], preserve_partitioning=[...],
metrics=[output_rows=..., elapsed_compute=..., metrics=[output_rows=..., elapsed_compute=..., row_replacements=...],
output_bytes=..., row_replacements=...] cumulative_cpu=...
KNNVectorDistance: elapsed=..., metric=l2, KNNVectorDistance: metric=l2,
metrics=[output_rows=..., elapsed_compute=..., metrics=[output_rows=..., elapsed_compute=..., output_batches=...],
output_bytes=..., output_batches=...] cumulative_cpu=...
LanceRead: elapsed=..., uri=..., projection=[vector], LanceRead: uri=..., projection=[vector], ...
num_fragments=..., range_before=None, range_after=None, metrics=[output_rows=..., elapsed_compute=...,
row_id=true, row_addr=false, bytes_read=..., iops=..., requests=...], cumulative_cpu=...
full_filter=--, refine_filter=--,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...,
fragments_scanned=..., ranges_scanned=1, rows_scanned=1,
bytes_read=..., iops=..., requests=..., task_wait_time=...]
Returns Returns
------- -------

View File

@@ -384,7 +384,6 @@ class RemoteDBConnection(DBConnection):
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
mode: Optional[str] = None, mode: Optional[str] = None,
exist_ok: bool = False,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*, *,
namespace: Optional[List[str]] = None, namespace: Optional[List[str]] = None,
@@ -413,12 +412,6 @@ class RemoteDBConnection(DBConnection):
- pyarrow.Schema - pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel] - [LanceModel][lancedb.pydantic.LanceModel]
mode: str, default "create"
The mode to use when creating the table.
Can be either "create", "overwrite", or "exist_ok".
exist_ok: bool, default False
If exist_ok is True, and mode is None or "create", mode will be changed
to "exist_ok".
on_bad_vectors: str, default "error" on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs. What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill". One of "error", "drop", "fill".
@@ -490,11 +483,6 @@ class RemoteDBConnection(DBConnection):
LanceTable(table4) LanceTable(table4)
""" """
if exist_ok:
if mode == "create":
mode = "exist_ok"
elif not mode:
mode = "exist_ok"
if namespace is None: if namespace is None:
namespace = [] namespace = []
validate_table_name(name) validate_table_name(name)

View File

@@ -18,17 +18,7 @@ from lancedb._lancedb import (
UpdateResult, UpdateResult,
) )
from lancedb.embeddings.base import EmbeddingFunctionConfig from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import ( from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
FTS,
BTree,
Bitmap,
HnswSq,
IvfFlat,
IvfPq,
IvfRq,
IvfSq,
LabelList,
)
from lancedb.remote.db import LOOP from lancedb.remote.db import LOOP
import pyarrow as pa import pyarrow as pa
@@ -275,12 +265,6 @@ class RemoteTable(Table):
num_sub_vectors=num_sub_vectors, num_sub_vectors=num_sub_vectors,
num_bits=num_bits, num_bits=num_bits,
) )
elif index_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
)
elif index_type == "IVF_SQ": elif index_type == "IVF_SQ":
config = IvfSq(distance_type=metric, num_partitions=num_partitions) config = IvfSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_HNSW_PQ": elif index_type == "IVF_HNSW_PQ":
@@ -295,8 +279,7 @@ class RemoteTable(Table):
else: else:
raise ValueError( raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are" f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ'," " 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
) )
LOOP.run( LOOP.run(
@@ -655,14 +638,6 @@ class RemoteTable(Table):
def stats(self): def stats(self):
return LOOP.run(self._table.stats()) return LOOP.run(self._table.stats())
@property
def uri(self) -> str:
"""The table URI (storage location).
For remote tables, this fetches the location from the server via describe.
"""
return LOOP.run(self._table.uri())
def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder: def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
return LanceTakeQueryBuilder(self._table.take_offsets(offsets)) return LanceTakeQueryBuilder(self._table.take_offsets(offsets))

View File

@@ -2218,10 +2218,6 @@ class LanceTable(Table):
def stats(self) -> TableStatistics: def stats(self) -> TableStatistics:
return LOOP.run(self._table.stats()) return LOOP.run(self._table.stats())
@property
def uri(self) -> str:
return LOOP.run(self._table.uri())
def create_scalar_index( def create_scalar_index(
self, self,
column: str, column: str,
@@ -3610,20 +3606,6 @@ class AsyncTable:
""" """
return await self._inner.stats() return await self._inner.stats()
async def uri(self) -> str:
"""
Get the table URI (storage location).
For remote tables, this fetches the location from the server via describe.
For local tables, this returns the dataset URI.
Returns
-------
str
The full storage location of the table (e.g., S3/GCS path).
"""
return await self._inner.uri()
async def add( async def add(
self, self,
data: DATA, data: DATA,

View File

@@ -613,133 +613,6 @@ def test_voyageai_multimodal_embedding_text_function():
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims() assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_embedding_function():
"""Test voyage-multimodal-3.5 model with text input."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == 1024
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_flexible_dimensions():
"""Test voyage-multimodal-3.5 model with custom output dimension."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
assert voyageai.ndims() == 512
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == 512
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_image_embedding():
"""Test voyage-multimodal-3.5 model with image input."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", max_retries=0)
)
class Images(LanceModel):
label: str
image_uri: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
db = lancedb.connect("~/lancedb")
table = db.create_table(
"test_multimodal_35_images", schema=Images, mode="overwrite"
)
labels = ["cat", "dog"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
]
table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == 1024
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
def test_voyageai_multimodal_35_all_dimensions(dimension):
"""Test voyage-multimodal-3.5 model with all valid output dimensions."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
)
assert voyageai.ndims() == dimension
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table(
f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == dimension
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_invalid_dimension():
"""Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
with pytest.raises(ValueError, match="Invalid output_dimension"):
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
)
# ndims() is where the validation happens
voyageai.ndims()
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif( @pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None, importlib.util.find_spec("colpali_engine") is None,

View File

@@ -168,42 +168,6 @@ def test_table_len_sync():
assert len(table) == 1 assert len(table) == 1
def test_create_table_exist_ok():
def handler(request):
if request.path == "/v1/table/test/create/?mode=exist_ok":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], exist_ok=True)
assert table is not None
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True)
assert table is not None
def test_create_table_exist_ok_with_mode_overwrite():
def handler(request):
if request.path == "/v1/table/test/create/?mode=overwrite":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
assert table is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_http_error(): async def test_http_error():
request_id_holder = {"request_id": None} request_id_holder = {"request_id": None}

View File

@@ -1,68 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""
Tests for S3 bucket names containing dots.
Related issue: https://github.com/lancedb/lancedb/issues/1898
These tests validate the early error checking for S3 bucket names with dots.
No actual S3 connection is made - validation happens before connection.
"""
import pytest
import lancedb
# Test URIs
BUCKET_WITH_DOTS = "s3://my.bucket.name/path"
BUCKET_WITH_DOTS_AND_REGION = ("s3://my.bucket.name", {"region": "us-east-1"})
BUCKET_WITH_DOTS_AND_AWS_REGION = ("s3://my.bucket.name", {"aws_region": "us-east-1"})
BUCKET_WITHOUT_DOTS = "s3://my-bucket/path"
class TestS3BucketWithDotsSync:
"""Tests for connect()."""
def test_bucket_with_dots_requires_region(self):
with pytest.raises(ValueError, match="contains dots"):
lancedb.connect(BUCKET_WITH_DOTS)
def test_bucket_with_dots_and_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_REGION
db = lancedb.connect(uri, storage_options=opts)
assert db is not None
def test_bucket_with_dots_and_aws_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
db = lancedb.connect(uri, storage_options=opts)
assert db is not None
def test_bucket_without_dots_passes(self):
db = lancedb.connect(BUCKET_WITHOUT_DOTS)
assert db is not None
class TestS3BucketWithDotsAsync:
"""Tests for connect_async()."""
@pytest.mark.asyncio
async def test_bucket_with_dots_requires_region(self):
with pytest.raises(ValueError, match="contains dots"):
await lancedb.connect_async(BUCKET_WITH_DOTS)
@pytest.mark.asyncio
async def test_bucket_with_dots_and_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_REGION
db = await lancedb.connect_async(uri, storage_options=opts)
assert db is not None
@pytest.mark.asyncio
async def test_bucket_with_dots_and_aws_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
db = await lancedb.connect_async(uri, storage_options=opts)
assert db is not None
@pytest.mark.asyncio
async def test_bucket_without_dots_passes(self):
db = await lancedb.connect_async(BUCKET_WITHOUT_DOTS)
assert db is not None

View File

@@ -1967,9 +1967,3 @@ def test_add_table_with_empty_embeddings(tmp_path):
on_bad_vectors="drop", on_bad_vectors="drop",
) )
assert table.count_rows() == 1 assert table.count_rows() == 1
def test_table_uri(tmp_path):
db = lancedb.connect(tmp_path)
table = db.create_table("my_table", data=[{"x": 0}])
assert table.uri == str(tmp_path / "my_table.lance")

View File

@@ -10,7 +10,8 @@ use arrow::{
use futures::stream::StreamExt; use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream; use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{ use pyo3::{
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
Python,
}; };
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
@@ -35,11 +36,8 @@ impl RecordBatchStream {
#[pymethods] #[pymethods]
impl RecordBatchStream { impl RecordBatchStream {
#[getter] #[getter]
pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> { pub fn schema(&self, py: Python) -> PyResult<PyObject> {
(*self.schema) (*self.schema).clone().into_pyarrow(py)
.clone()
.into_pyarrow(py)
.map(|obj| obj.unbind())
} }
pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> { pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
@@ -55,12 +53,7 @@ impl RecordBatchStream {
.next() .next()
.await .await
.ok_or_else(|| PyStopAsyncIteration::new_err(""))?; .ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
#[allow(deprecated)] Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = inner_next.infer_error()?.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }
} }

View File

@@ -12,7 +12,7 @@ use pyo3::{
exceptions::{PyRuntimeError, PyValueError}, exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, pyclass, pyfunction, pymethods,
types::{PyDict, PyDictMethods}, types::{PyDict, PyDictMethods},
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python, Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
}; };
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
@@ -114,7 +114,7 @@ impl Connection {
data: Bound<'_, PyAny>, data: Bound<'_, PyAny>,
namespace: Vec<String>, namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>, storage_options_provider: Option<PyObject>,
location: Option<String>, location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> { ) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone(); let inner = self_.get_inner()?.clone();
@@ -152,7 +152,7 @@ impl Connection {
schema: Bound<'_, PyAny>, schema: Bound<'_, PyAny>,
namespace: Vec<String>, namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>, storage_options_provider: Option<PyObject>,
location: Option<String>, location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> { ) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone(); let inner = self_.get_inner()?.clone();
@@ -187,7 +187,7 @@ impl Connection {
name: String, name: String,
namespace: Vec<String>, namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>, storage_options_provider: Option<PyObject>,
index_cache_size: Option<u32>, index_cache_size: Option<u32>,
location: Option<String>, location: Option<String>,
) -> PyResult<Bound<'_, PyAny>> { ) -> PyResult<Bound<'_, PyAny>> {
@@ -304,10 +304,8 @@ impl Connection {
}, },
page_token, page_token,
limit: limit.map(|l| l as i32), limit: limit.map(|l| l as i32),
..Default::default()
}; };
let response = inner.list_namespaces(request).await.infer_error()?; let response = inner.list_namespaces(request).await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| -> PyResult<Py<PyDict>> { Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("namespaces", response.namespaces)?; dict.set_item("namespaces", response.namespaces)?;
@@ -328,11 +326,11 @@ impl Connection {
let py = self_.py(); let py = self_.py();
future_into_py(py, async move { future_into_py(py, async move {
use lance_namespace::models::CreateNamespaceRequest; use lance_namespace::models::CreateNamespaceRequest;
let mode_enum = mode.and_then(|m| match m.to_lowercase().as_str() { let mode_value = mode.map(|m| match m.to_lowercase().as_str() {
"create" => Some("Create".to_string()), "create" => "Create".to_string(),
"exist_ok" => Some("ExistOk".to_string()), "exist_ok" => "ExistOk".to_string(),
"overwrite" => Some("Overwrite".to_string()), "overwrite" => "Overwrite".to_string(),
_ => None, _ => m,
}); });
let request = CreateNamespaceRequest { let request = CreateNamespaceRequest {
id: if namespace.is_empty() { id: if namespace.is_empty() {
@@ -340,12 +338,10 @@ impl Connection {
} else { } else {
Some(namespace) Some(namespace)
}, },
mode: mode_enum, mode: mode_value,
properties, properties,
..Default::default()
}; };
let response = inner.create_namespace(request).await.infer_error()?; let response = inner.create_namespace(request).await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| -> PyResult<Py<PyDict>> { Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?; dict.set_item("properties", response.properties)?;
@@ -365,15 +361,15 @@ impl Connection {
let py = self_.py(); let py = self_.py();
future_into_py(py, async move { future_into_py(py, async move {
use lance_namespace::models::DropNamespaceRequest; use lance_namespace::models::DropNamespaceRequest;
let mode_enum = mode.and_then(|m| match m.to_uppercase().as_str() { let mode_value = mode.map(|m| match m.to_uppercase().as_str() {
"SKIP" => Some("Skip".to_string()), "SKIP" => "Skip".to_string(),
"FAIL" => Some("Fail".to_string()), "FAIL" => "Fail".to_string(),
_ => None, _ => m,
}); });
let behavior_enum = behavior.and_then(|b| match b.to_uppercase().as_str() { let behavior_value = behavior.map(|b| match b.to_uppercase().as_str() {
"RESTRICT" => Some("Restrict".to_string()), "RESTRICT" => "Restrict".to_string(),
"CASCADE" => Some("Cascade".to_string()), "CASCADE" => "Cascade".to_string(),
_ => None, _ => b,
}); });
let request = DropNamespaceRequest { let request = DropNamespaceRequest {
id: if namespace.is_empty() { id: if namespace.is_empty() {
@@ -381,12 +377,10 @@ impl Connection {
} else { } else {
Some(namespace) Some(namespace)
}, },
mode: mode_enum, mode: mode_value,
behavior: behavior_enum, behavior: behavior_value,
..Default::default()
}; };
let response = inner.drop_namespace(request).await.infer_error()?; let response = inner.drop_namespace(request).await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| -> PyResult<Py<PyDict>> { Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?; dict.set_item("properties", response.properties)?;
@@ -411,10 +405,8 @@ impl Connection {
} else { } else {
Some(namespace) Some(namespace)
}, },
..Default::default()
}; };
let response = inner.describe_namespace(request).await.infer_error()?; let response = inner.describe_namespace(request).await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| -> PyResult<Py<PyDict>> { Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?; dict.set_item("properties", response.properties)?;
@@ -442,10 +434,8 @@ impl Connection {
}, },
page_token, page_token,
limit: limit.map(|l| l as i32), limit: limit.map(|l| l as i32),
..Default::default()
}; };
let response = inner.list_tables(request).await.infer_error()?; let response = inner.list_tables(request).await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| -> PyResult<Py<PyDict>> { Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("tables", response.tables)?; dict.set_item("tables", response.tables)?;

View File

@@ -40,34 +40,31 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
request_id, request_id,
source, source,
status_code, status_code,
} => { } => Python::with_gil(|py| {
#[allow(deprecated)] let message = err.to_string();
Python::with_gil(|py| { let http_err_cls = py
let message = err.to_string(); .import(intern!(py, "lancedb.remote.errors"))?
let http_err_cls = py .getattr(intern!(py, "HttpError"))?;
.import(intern!(py, "lancedb.remote.errors"))? let err = http_err_cls.call1((
.getattr(intern!(py, "HttpError"))?; message,
let err = http_err_cls.call1(( request_id,
message, status_code.map(|s| s.as_u16()),
))?;
if let Some(cause) = source.source() {
// The HTTP error already includes the first cause. But
// we can add the rest of the chain if there is any more.
let cause_err = http_from_rust_error(
py,
cause,
request_id, request_id,
status_code.map(|s| s.as_u16()), status_code.map(|s| s.as_u16()),
))?; )?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
}
if let Some(cause) = source.source() { Err(PyErr::from_value(err))
// The HTTP error already includes the first cause. But }),
// we can add the rest of the chain if there is any more.
let cause_err = http_from_rust_error(
py,
cause,
request_id,
status_code.map(|s| s.as_u16()),
)?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
}
Err(PyErr::from_value(err))
})
}
LanceError::Retry { LanceError::Retry {
request_id, request_id,
request_failures, request_failures,
@@ -78,37 +75,33 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
max_read_failures, max_read_failures,
source, source,
status_code, status_code,
} => } => Python::with_gil(|py| {
{ let cause_err = http_from_rust_error(
#[allow(deprecated)] py,
Python::with_gil(|py| { source.as_ref(),
let cause_err = http_from_rust_error( request_id,
py, status_code.map(|s| s.as_u16()),
source.as_ref(), )?;
request_id,
status_code.map(|s| s.as_u16()),
)?;
let message = err.to_string(); let message = err.to_string();
let retry_error_cls = py let retry_error_cls = py
.import(intern!(py, "lancedb.remote.errors"))? .import(intern!(py, "lancedb.remote.errors"))?
.getattr("RetryError")?; .getattr("RetryError")?;
let err = retry_error_cls.call1(( let err = retry_error_cls.call1((
message, message,
request_id, request_id,
*request_failures, *request_failures,
*connect_failures, *connect_failures,
*read_failures, *read_failures,
*max_request_failures, *max_request_failures,
*max_connect_failures, *max_connect_failures,
*max_read_failures, *max_read_failures,
status_code.map(|s| s.as_u16()), status_code.map(|s| s.as_u16()),
))?; ))?;
err.setattr(intern!(py, "__cause__"), cause_err)?; err.setattr(intern!(py, "__cause__"), cause_err)?;
Err(PyErr::from_value(err)) Err(PyErr::from_value(err))
}) }),
}
_ => self.runtime_error(), _ => self.runtime_error(),
}, },
} }

View File

@@ -12,7 +12,6 @@ pub struct PyHeaderProvider {
impl Clone for PyHeaderProvider { impl Clone for PyHeaderProvider {
fn clone(&self) -> Self { fn clone(&self) -> Self {
#[allow(deprecated)]
Python::with_gil(|py| Self { Python::with_gil(|py| Self {
provider: self.provider.clone_ref(py), provider: self.provider.clone_ref(py),
}) })
@@ -26,7 +25,6 @@ impl PyHeaderProvider {
/// Get headers from the Python provider (internal implementation) /// Get headers from the Python provider (internal implementation)
fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> { fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
// Call the get_headers method // Call the get_headers method
let result = self.provider.call_method0(py, "get_headers"); let result = self.provider.call_method0(py, "get_headers");

View File

@@ -19,7 +19,7 @@ use pyo3::{
exceptions::PyRuntimeError, exceptions::PyRuntimeError,
pyclass, pymethods, pyclass, pymethods,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType}, types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, Py, PyAny, PyRef, PyRefMut, PyResult, Python, Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
}; };
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
@@ -281,12 +281,7 @@ impl PyPermutationReader {
let reader = slf.reader.clone(); let reader = slf.reader.clone();
future_into_py(slf.py(), async move { future_into_py(slf.py(), async move {
let schema = reader.output_schema(selection).await.infer_error()?; let schema = reader.output_schema(selection).await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }

View File

@@ -29,7 +29,6 @@ use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString}; use pyo3::types::{PyDict, PyString};
use pyo3::Bound; use pyo3::Bound;
use pyo3::IntoPyObject; use pyo3::IntoPyObject;
use pyo3::Py;
use pyo3::PyAny; use pyo3::PyAny;
use pyo3::PyRef; use pyo3::PyRef;
use pyo3::PyResult; use pyo3::PyResult;
@@ -454,12 +453,7 @@ impl Query {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?; let schema = inner.output_schema().await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }
@@ -538,12 +532,7 @@ impl TakeQuery {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?; let schema = inner.output_schema().await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }
@@ -638,12 +627,7 @@ impl FTSQuery {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?; let schema = inner.output_schema().await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }
@@ -822,12 +806,7 @@ impl VectorQuery {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?; let schema = inner.output_schema().await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }

View File

@@ -17,12 +17,11 @@ use pyo3::types::PyDict;
/// Internal wrapper around a Python object implementing StorageOptionsProvider /// Internal wrapper around a Python object implementing StorageOptionsProvider
pub struct PyStorageOptionsProvider { pub struct PyStorageOptionsProvider {
/// The Python object implementing fetch_storage_options() /// The Python object implementing fetch_storage_options()
inner: Py<PyAny>, inner: PyObject,
} }
impl Clone for PyStorageOptionsProvider { impl Clone for PyStorageOptionsProvider {
fn clone(&self) -> Self { fn clone(&self) -> Self {
#[allow(deprecated)]
Python::with_gil(|py| Self { Python::with_gil(|py| Self {
inner: self.inner.clone_ref(py), inner: self.inner.clone_ref(py),
}) })
@@ -30,8 +29,7 @@ impl Clone for PyStorageOptionsProvider {
} }
impl PyStorageOptionsProvider { impl PyStorageOptionsProvider {
pub fn new(obj: Py<PyAny>) -> PyResult<Self> { pub fn new(obj: PyObject) -> PyResult<Self> {
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
// Verify the object has a fetch_storage_options method // Verify the object has a fetch_storage_options method
if !obj.bind(py).hasattr("fetch_storage_options")? { if !obj.bind(py).hasattr("fetch_storage_options")? {
@@ -39,9 +37,7 @@ impl PyStorageOptionsProvider {
"StorageOptionsProvider must implement fetch_storage_options() method", "StorageOptionsProvider must implement fetch_storage_options() method",
)); ));
} }
Ok(Self { Ok(Self { inner: obj })
inner: obj.clone_ref(py),
})
}) })
} }
} }
@@ -64,7 +60,6 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
let py_provider = self.py_provider.clone(); let py_provider = self.py_provider.clone();
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
// Call the Python fetch_storage_options method // Call the Python fetch_storage_options method
let result = py_provider let result = py_provider
@@ -124,7 +119,6 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
} }
fn provider_id(&self) -> String { fn provider_id(&self) -> String {
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
// Call provider_id() method on the Python object // Call provider_id() method on the Python object
let obj = self.py_provider.inner.bind(py); let obj = self.py_provider.inner.bind(py);
@@ -149,7 +143,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
/// This is the main entry point for converting Python StorageOptionsProvider objects /// This is the main entry point for converting Python StorageOptionsProvider objects
/// to Rust trait objects that can be used by the Lance ecosystem. /// to Rust trait objects that can be used by the Lance ecosystem.
pub fn py_object_to_storage_options_provider( pub fn py_object_to_storage_options_provider(
py_obj: Py<PyAny>, py_obj: PyObject,
) -> PyResult<Arc<dyn StorageOptionsProvider>> { ) -> PyResult<Arc<dyn StorageOptionsProvider>> {
let py_provider = PyStorageOptionsProvider::new(py_obj)?; let py_provider = PyStorageOptionsProvider::new(py_obj)?;
Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider))) Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))

View File

@@ -21,7 +21,7 @@ use pyo3::{
exceptions::{PyKeyError, PyRuntimeError, PyValueError}, exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods, pyclass, pymethods,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods}, types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
}; };
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
@@ -287,12 +287,7 @@ impl Table {
let inner = self_.inner_ref()?.clone(); let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let schema = inner.schema().await.infer_error()?; let schema = inner.schema().await.infer_error()?;
#[allow(deprecated)] Python::with_gil(|py| schema.to_pyarrow(py))
let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let bound = schema.to_pyarrow(py)?;
Ok(bound.unbind())
})?;
Ok(py_obj)
}) })
} }
@@ -442,7 +437,6 @@ impl Table {
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let stats = inner.index_stats(&index_name).await.infer_error()?; let stats = inner.index_stats(&index_name).await.infer_error()?;
if let Some(stats) = stats { if let Some(stats) = stats {
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("num_indexed_rows", stats.num_indexed_rows)?; dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
@@ -473,7 +467,6 @@ impl Table {
let inner = self_.inner_ref()?.clone(); let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let stats = inner.stats().await.infer_error()?; let stats = inner.stats().await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
let dict = PyDict::new(py); let dict = PyDict::new(py);
dict.set_item("total_bytes", stats.total_bytes)?; dict.set_item("total_bytes", stats.total_bytes)?;
@@ -504,11 +497,6 @@ impl Table {
}) })
} }
pub fn uri(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { inner.uri().await.infer_error() })
}
pub fn __repr__(&self) -> String { pub fn __repr__(&self) -> String {
match &self.inner { match &self.inner {
None => format!("ClosedTable({})", self.name), None => format!("ClosedTable({})", self.name),
@@ -528,7 +516,6 @@ impl Table {
let inner = self_.inner_ref()?.clone(); let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let versions = inner.list_versions().await.infer_error()?; let versions = inner.list_versions().await.infer_error()?;
#[allow(deprecated)]
let versions_as_dict = Python::with_gil(|py| { let versions_as_dict = Python::with_gil(|py| {
versions versions
.iter() .iter()
@@ -880,7 +867,6 @@ impl Tags {
let tags = inner.tags().await.infer_error()?; let tags = inner.tags().await.infer_error()?;
let res = tags.list().await.infer_error()?; let res = tags.list().await.infer_error()?;
#[allow(deprecated)]
Python::with_gil(|py| { Python::with_gil(|py| {
let py_dict = PyDict::new(py); let py_dict = PyDict::new(py);
for (key, contents) in res { for (key, contents) in res {

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.23.1" version = "0.23.1-beta.1"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true
@@ -104,16 +104,11 @@ test-log = "0.2"
[features] [features]
default = [] default = ["aws", "gcs", "azure", "dynamodb", "oss"]
aws = ["lance/aws", "lance-io/aws", "lance-namespace-impls/dir-aws"] aws = ["lance/aws", "lance-io/aws", "lance-namespace-impls/dir-aws"]
oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"] oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"]
gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"] gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"]
azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"] azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"]
huggingface = [
"lance/huggingface",
"lance-io/huggingface",
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"] dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"] remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"] fp16kernels = ["lance-linalg/fp16kernels"]
@@ -153,6 +148,3 @@ name = "ivf_pq"
[[example]] [[example]]
name = "hybrid_search" name = "hybrid_search"
required-features = ["sentence-transformers"] required-features = ["sentence-transformers"]
[package.metadata.docs.rs]
all-features = true

View File

@@ -1325,27 +1325,25 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_table_names() { async fn test_table_names() {
let tc = new_test_connection().await.unwrap(); let tmp_dir = tempdir().unwrap();
let db = tc.connection;
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
let mut names = Vec::with_capacity(100); let mut names = Vec::with_capacity(100);
for _ in 0..100 { for _ in 0..100 {
let name = uuid::Uuid::new_v4().to_string(); let mut name = uuid::Uuid::new_v4().to_string();
names.push(name.clone()); names.push(name.clone());
db.create_empty_table(name, schema.clone()) name.push_str(".lance");
.execute() create_dir_all(tmp_dir.path().join(&name)).unwrap();
.await
.unwrap();
} }
names.sort(); names.sort();
let tables = db.table_names().limit(100).execute().await.unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tables = db.table_names().execute().await.unwrap();
assert_eq!(tables, names); assert_eq!(tables, names);
let tables = db let tables = db
.table_names() .table_names()
.start_after(&names[30]) .start_after(&names[30])
.limit(100)
.execute() .execute()
.await .await
.unwrap(); .unwrap();

View File

@@ -463,20 +463,9 @@ impl ListingDatabase {
validate_table_name(name)?; validate_table_name(name)?;
let mut uri = self.uri.clone(); let mut uri = self.uri.clone();
// If the URI does not end with a path separator, add one // If the URI does not end with a slash, add one
// Use forward slash for URIs (http://, s3://, gs://, file://, etc.) if !uri.ends_with('/') {
// Use platform-specific separator for local paths without scheme uri.push('/');
let has_scheme = uri.contains("://");
let ends_with_separator = uri.ends_with('/') || uri.ends_with('\\');
if !ends_with_separator {
if has_scheme {
// URIs always use forward slash
uri.push('/');
} else {
// Local path without scheme - use platform separator
uri.push(std::path::MAIN_SEPARATOR);
}
} }
// Append the table name with the lance file extension // Append the table name with the lance file extension
uri.push_str(&format!("{}.{}", name, LANCE_FILE_EXTENSION)); uri.push_str(&format!("{}.{}", name, LANCE_FILE_EXTENSION));
@@ -1082,7 +1071,6 @@ mod tests {
use crate::table::{Table, TableDefinition}; use crate::table::{Table, TableDefinition};
use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema}; use arrow_schema::{DataType, Field, Schema};
use std::path::PathBuf;
use tempfile::tempdir; use tempfile::tempdir;
async fn setup_database() -> (tempfile::TempDir, ListingDatabase) { async fn setup_database() -> (tempfile::TempDir, ListingDatabase) {
@@ -2058,19 +2046,6 @@ mod tests {
assert_eq!(db_options.new_table_config.enable_stable_row_ids, None); assert_eq!(db_options.new_table_config.enable_stable_row_ids, None);
} }
#[tokio::test]
async fn test_table_uri() {
let (_tempdir, db) = setup_database().await;
let mut pb = PathBuf::new();
pb.push(db.uri.clone());
pb.push("test.lance");
let expected = pb.to_str().unwrap();
let uri = db.table_uri("test").ok().unwrap();
assert_eq!(uri, expected);
}
#[tokio::test] #[tokio::test]
async fn test_namespace_client() { async fn test_namespace_client() {
let (_tempdir, db) = setup_database().await; let (_tempdir, db) = setup_database().await;

View File

@@ -9,7 +9,7 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use lance_namespace::{ use lance_namespace::{
models::{ models::{
CreateNamespaceRequest, CreateNamespaceResponse, DeclareTableRequest, CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
ListNamespacesResponse, ListTablesRequest, ListTablesResponse, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
@@ -137,7 +137,6 @@ impl Database for LanceNamespaceDatabase {
id: Some(request.namespace), id: Some(request.namespace),
page_token: request.start_after, page_token: request.start_after,
limit: request.limit.map(|l| l as i32), limit: request.limit.map(|l| l as i32),
..Default::default()
}; };
let response = self.namespace.list_tables(ns_request).await?; let response = self.namespace.list_tables(ns_request).await?;
@@ -155,7 +154,7 @@ impl Database for LanceNamespaceDatabase {
let describe_request = DescribeTableRequest { let describe_request = DescribeTableRequest {
id: Some(table_id.clone()), id: Some(table_id.clone()),
version: None, version: None,
..Default::default() with_table_uri: None,
}; };
let describe_result = self.namespace.describe_table(describe_request).await; let describe_result = self.namespace.describe_table(describe_request).await;
@@ -173,7 +172,6 @@ impl Database for LanceNamespaceDatabase {
// Drop the existing table - must succeed // Drop the existing table - must succeed
let drop_request = DropTableRequest { let drop_request = DropTableRequest {
id: Some(table_id.clone()), id: Some(table_id.clone()),
..Default::default()
}; };
self.namespace self.namespace
.drop_table(drop_request) .drop_table(drop_request)
@@ -205,19 +203,22 @@ impl Database for LanceNamespaceDatabase {
let mut table_id = request.namespace.clone(); let mut table_id = request.namespace.clone();
table_id.push(request.name.clone()); table_id.push(request.name.clone());
let create_empty_request = DeclareTableRequest { let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()), id: Some(table_id.clone()),
location: None, location: None,
vend_credentials: None, properties: if self.storage_options.is_empty() {
..Default::default() None
} else {
Some(self.storage_options.clone())
},
}; };
let create_empty_response = self let create_empty_response = self
.namespace .namespace
.declare_table(create_empty_request) .create_empty_table(create_empty_request)
.await .await
.map_err(|e| Error::Runtime { .map_err(|e| Error::Runtime {
message: format!("Failed to declare table: {}", e), message: format!("Failed to create empty table: {}", e),
})?; })?;
let location = create_empty_response let location = create_empty_response
@@ -281,10 +282,7 @@ impl Database for LanceNamespaceDatabase {
let mut table_id = namespace.to_vec(); let mut table_id = namespace.to_vec();
table_id.push(name.to_string()); table_id.push(name.to_string());
let drop_request = DropTableRequest { let drop_request = DropTableRequest { id: Some(table_id) };
id: Some(table_id),
..Default::default()
};
self.namespace self.namespace
.drop_table(drop_request) .drop_table(drop_request)
.await .await
@@ -441,7 +439,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -503,7 +500,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -568,7 +564,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -653,7 +648,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -710,7 +704,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -792,7 +785,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -827,7 +819,6 @@ mod tests {
id: Some(vec!["test_ns".into()]), id: Some(vec!["test_ns".into()]),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");

View File

@@ -120,13 +120,8 @@ impl MemoryRegistry {
} }
/// A record batch reader that has embeddings applied to it /// A record batch reader that has embeddings applied to it
/// /// This is a wrapper around another record batch reader that applies an embedding function
/// This is a wrapper around another record batch reader that applies embedding functions /// when reading from the record batch
/// when reading from the record batch.
///
/// When multiple embedding functions are defined, they are computed in parallel using
/// scoped threads to improve performance. For a single embedding function, computation
/// is done inline without threading overhead.
pub struct WithEmbeddings<R: RecordBatchReader> { pub struct WithEmbeddings<R: RecordBatchReader> {
inner: R, inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>, embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
@@ -240,48 +235,6 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
column_definitions, column_definitions,
}) })
} }
fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
if self.embeddings.len() == 1 {
let (fld, func) = &self.embeddings[0];
let src_column =
batch
.column_by_name(&fld.source_column)
.ok_or_else(|| Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
})?;
return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
}
// Parallel path: multiple embeddings
std::thread::scope(|s| {
let handles: Vec<_> = self
.embeddings
.iter()
.map(|(fld, func)| {
let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
}
})?;
let handle =
s.spawn(move || func.compute_source_embeddings(src_column.clone()));
Ok(handle)
})
.collect::<Result<_>>()?;
handles
.into_iter()
.map(|h| {
h.join().map_err(|e| Error::Runtime {
message: format!("Thread panicked during embedding computation: {:?}", e),
})?
})
.collect()
})
}
} }
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> { impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
@@ -309,19 +262,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let batch = self.inner.next()?; let batch = self.inner.next()?;
match batch { match batch {
Ok(batch) => { Ok(mut batch) => {
let embeddings = match self.compute_embeddings_parallel(&batch) { // todo: parallelize this
Ok(emb) => emb, for (fld, func) in self.embeddings.iter() {
Err(e) => { let src_column = batch.column_by_name(&fld.source_column).unwrap();
return Some(Err(arrow_schema::ArrowError::ComputeError(format!( let embedding = match func.compute_source_embeddings(src_column.clone()) {
"Error computing embedding: {}", Ok(embedding) => embedding,
e Err(e) => {
)))) return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
} "Error computing embedding: {}",
}; e
))))
let mut batch = batch; }
for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) { };
let dst_field_name = fld let dst_field_name = fld
.dest_column .dest_column
.clone() .clone()
@@ -333,7 +286,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
embedding.nulls().is_some(), embedding.nulls().is_some(),
); );
match batch.try_with_column(dst_field.clone(), embedding.clone()) { match batch.try_with_column(dst_field.clone(), embedding) {
Ok(b) => batch = b, Ok(b) => batch = b,
Err(e) => return Some(Err(e)), Err(e) => return Some(Err(e)),
}; };

View File

@@ -25,14 +25,13 @@
//! //!
//! ## Crate Features //! ## Crate Features
//! //!
//! - `aws` - Enable AWS S3 object store support. //! ### Experimental Features
//! - `dynamodb` - Enable DynamoDB manifest store support. //!
//! - `azure` - Enable Azure Blob Storage object store support. //! These features are not enabled by default. They are experimental or in-development features that
//! - `gcs` - Enable Google Cloud Storage object store support. //! are not yet ready to be released.
//! - `oss` - Enable Alibaba Cloud OSS object store support. //!
//! - `remote` - Enable remote client to connect to LanceDB cloud. //! - `remote` - Enable remote client to connect to LanceDB cloud. This is not yet fully implemented
//! - `huggingface` - Enable HuggingFace Hub integration for loading datasets from the Hub. //! and should not be enabled.
//! - `fp16kernels` - Enable FP16 kernels for faster vector search on CPU.
//! //!
//! ### Quick Start //! ### Quick Start
//! //!
@@ -54,8 +53,6 @@
//! You can also use [`ConnectOptions`] to configure the connection to the database. //! You can also use [`ConnectOptions`] to configure the connection to the database.
//! //!
//! ```rust //! ```rust
//! # #[cfg(feature = "aws")]
//! # {
//! use object_store::aws::AwsCredential; //! use object_store::aws::AwsCredential;
//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = lancedb::connect("data/sample-lancedb") //! let db = lancedb::connect("data/sample-lancedb")
@@ -68,7 +65,6 @@
//! .await //! .await
//! .unwrap(); //! .unwrap();
//! # }); //! # });
//! # }
//! ``` //! ```
//! //!
//! LanceDB uses [arrow-rs](https://github.com/apache/arrow-rs) to define schema, data types and array itself. //! LanceDB uses [arrow-rs](https://github.com/apache/arrow-rs) to define schema, data types and array itself.

View File

@@ -1720,7 +1720,6 @@ mod tests {
id: Some(namespace.clone()), id: Some(namespace.clone()),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -1747,7 +1746,6 @@ mod tests {
id: Some(namespace.clone()), id: Some(namespace.clone()),
page_token: None, page_token: None,
limit: None, limit: None,
..Default::default()
}) })
.await .await
.expect("Failed to list tables"); .expect("Failed to list tables");
@@ -1760,7 +1758,6 @@ mod tests {
id: Some(namespace.clone()), id: Some(namespace.clone()),
page_token: None, page_token: None,
limit: None, limit: None,
..Default::default()
}) })
.await .await
.unwrap(); .unwrap();
@@ -1802,7 +1799,6 @@ mod tests {
id: Some(namespace.clone()), id: Some(namespace.clone()),
mode: None, mode: None,
properties: None, properties: None,
..Default::default()
}) })
.await .await
.expect("Failed to create namespace"); .expect("Failed to create namespace");
@@ -1829,7 +1825,6 @@ mod tests {
id: Some(namespace.clone()), id: Some(namespace.clone()),
page_token: None, page_token: None,
limit: None, limit: None,
..Default::default()
}) })
.await .await
.unwrap(); .unwrap();

View File

@@ -204,7 +204,6 @@ pub struct RemoteTable<S: HttpSend = Sender> {
server_version: ServerVersion, server_version: ServerVersion,
version: RwLock<Option<u64>>, version: RwLock<Option<u64>>,
location: RwLock<Option<String>>,
} }
impl<S: HttpSend> RemoteTable<S> { impl<S: HttpSend> RemoteTable<S> {
@@ -222,7 +221,6 @@ impl<S: HttpSend> RemoteTable<S> {
identifier, identifier,
server_version, server_version,
version: RwLock::new(None), version: RwLock::new(None),
location: RwLock::new(None),
} }
} }
@@ -641,7 +639,6 @@ impl<S: HttpSend> RemoteTable<S> {
struct TableDescription { struct TableDescription {
version: u64, version: u64,
schema: JsonSchema, schema: JsonSchema,
location: Option<String>,
} }
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> { impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
@@ -670,7 +667,6 @@ mod test_utils {
identifier: name, identifier: name,
server_version: version.map(ServerVersion).unwrap_or_default(), server_version: version.map(ServerVersion).unwrap_or_default(),
version: RwLock::new(None), version: RwLock::new(None),
location: RwLock::new(None),
} }
} }
} }
@@ -1092,17 +1088,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
} }
} }
Index::IvfRq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::BTree(_) => { Index::BTree(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string()); body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
} }
@@ -1465,28 +1450,8 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
message: "table_definition is not supported on LanceDB cloud.".into(), message: "table_definition is not supported on LanceDB cloud.".into(),
}) })
} }
async fn uri(&self) -> Result<String> { fn dataset_uri(&self) -> &str {
// Check if we already have the location cached "NOT_SUPPORTED"
{
let location = self.location.read().await;
if let Some(ref loc) = *location {
return Ok(loc.clone());
}
}
// Fetch from server via describe
let description = self.describe().await?;
let location = description.location.ok_or_else(|| Error::NotSupported {
message: "Table URI not supported by the server".into(),
})?;
// Cache the location for future use
{
let mut cached_location = self.location.write().await;
*cached_location = Some(location.clone());
}
Ok(location)
} }
async fn storage_options(&self) -> Option<HashMap<String, String>> { async fn storage_options(&self) -> Option<HashMap<String, String>> {
@@ -3356,69 +3321,4 @@ mod tests {
let result = table.drop_columns(&["old_col1", "old_col2"]).await.unwrap(); let result = table.drop_columns(&["old_col1", "old_col2"]).await.unwrap();
assert_eq!(result.version, 5); assert_eq!(result.version, 5);
} }
#[tokio::test]
async fn test_uri() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
http::Response::builder()
.status(200)
.body(r#"{"version": 1, "schema": {"fields": []}, "location": "s3://bucket/path/to/table"}"#)
.unwrap()
});
let uri = table.uri().await.unwrap();
assert_eq!(uri, "s3://bucket/path/to/table");
}
#[tokio::test]
async fn test_uri_missing_location() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
// Server returns response without location field
http::Response::builder()
.status(200)
.body(r#"{"version": 1, "schema": {"fields": []}}"#)
.unwrap()
});
let result = table.uri().await;
assert!(result.is_err());
assert!(matches!(&result, Err(Error::NotSupported { .. })));
}
#[tokio::test]
async fn test_uri_caching() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
call_count_clone.fetch_add(1, Ordering::SeqCst);
http::Response::builder()
.status(200)
.body(
r#"{"version": 1, "schema": {"fields": []}, "location": "gs://bucket/table"}"#,
)
.unwrap()
});
// First call should fetch from server
let uri1 = table.uri().await.unwrap();
assert_eq!(uri1, "gs://bucket/table");
assert_eq!(call_count.load(Ordering::SeqCst), 1);
// Second call should use cached value
let uri2 = table.uri().await.unwrap();
assert_eq!(uri2, "gs://bucket/table");
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
}
} }

View File

@@ -608,8 +608,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn list_versions(&self) -> Result<Vec<Version>>; async fn list_versions(&self) -> Result<Vec<Version>>;
/// Get the table definition. /// Get the table definition.
async fn table_definition(&self) -> Result<TableDefinition>; async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI (storage location) /// Get the table URI
async fn uri(&self) -> Result<String>; fn dataset_uri(&self) -> &str;
/// Get the storage options used when opening this table, if any. /// Get the storage options used when opening this table, if any.
async fn storage_options(&self) -> Option<HashMap<String, String>>; async fn storage_options(&self) -> Option<HashMap<String, String>>;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns /// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
@@ -1317,12 +1317,11 @@ impl Table {
self.inner.list_indices().await self.inner.list_indices().await
} }
/// Get the table URI (storage location) /// Get the underlying dataset URI
/// ///
/// Returns the full storage location of the table (e.g., S3/GCS path). /// Warning: This is an internal API and the return value is subject to change.
/// For remote tables, this fetches the location from the server via describe. pub fn dataset_uri(&self) -> &str {
pub async fn uri(&self) -> Result<String> { self.inner.dataset_uri()
self.inner.uri().await
} }
/// Get the storage options used when opening this table, if any. /// Get the storage options used when opening this table, if any.
@@ -1425,9 +1424,7 @@ impl Table {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let unioned = UnionExec::try_new(projected_plans).map_err(|e| Error::Runtime { let unioned = Arc::new(UnionExec::new(projected_plans));
message: format!("Failed to build union plan: {e}"),
})?;
// We require 1 partition in the final output // We require 1 partition in the final output
let repartitioned = RepartitionExec::try_new( let repartitioned = RepartitionExec::try_new(
unioned, unioned,
@@ -2349,7 +2346,7 @@ impl NativeTable {
}; };
// Convert select to columns list // Convert select to columns list
let columns: Option<Box<QueryTableRequestColumns>> = match &vq.base.select { let columns = match &vq.base.select {
Select::All => None, Select::All => None,
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
column_names: Some(cols.clone()), column_names: Some(cols.clone()),
@@ -2408,7 +2405,6 @@ impl NativeTable {
bypass_vector_index: Some(!vq.use_index), bypass_vector_index: Some(!vq.use_index),
full_text_query, full_text_query,
version: None, version: None,
..Default::default()
}) })
} }
AnyQuery::Query(q) => { AnyQuery::Query(q) => {
@@ -2426,7 +2422,7 @@ impl NativeTable {
.map(|f| self.filter_to_sql(f)) .map(|f| self.filter_to_sql(f))
.transpose()?; .transpose()?;
let columns: Option<Box<QueryTableRequestColumns>> = match &q.select { let columns = match &q.select {
Select::All => None, Select::All => None,
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
column_names: Some(cols.clone()), column_names: Some(cols.clone()),
@@ -2482,7 +2478,6 @@ impl NativeTable {
fast_search: None, fast_search: None,
lower_bound: None, lower_bound: None,
upper_bound: None, upper_bound: None,
..Default::default()
}) })
} }
} }
@@ -3235,8 +3230,8 @@ impl BaseTable for NativeTable {
Ok(results.into_iter().flatten().collect()) Ok(results.into_iter().flatten().collect())
} }
async fn uri(&self) -> Result<String> { fn dataset_uri(&self) -> &str {
Ok(self.uri.clone()) self.uri.as_str()
} }
async fn storage_options(&self) -> Option<HashMap<String, String>> { async fn storage_options(&self) -> Option<HashMap<String, String>> {
@@ -5154,15 +5149,16 @@ mod tests {
let any_query = AnyQuery::VectorQuery(vq); let any_query = AnyQuery::VectorQuery(vq);
let ns_request = table.convert_to_namespace_query(&any_query).unwrap(); let ns_request = table.convert_to_namespace_query(&any_query).unwrap();
let column_names = ns_request
.columns
.as_ref()
.and_then(|cols| cols.column_names.clone());
assert_eq!(ns_request.k, 10); assert_eq!(ns_request.k, 10);
assert_eq!(ns_request.offset, Some(5)); assert_eq!(ns_request.offset, Some(5));
assert_eq!(ns_request.filter, Some("id > 0".to_string())); assert_eq!(ns_request.filter, Some("id > 0".to_string()));
assert_eq!(column_names, Some(vec!["id".to_string()])); assert_eq!(
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.clone()),
Some(vec!["id".to_string()])
);
assert_eq!(ns_request.vector_column, Some("vector".to_string())); assert_eq!(ns_request.vector_column, Some("vector".to_string()));
assert_eq!(ns_request.distance_type, Some("l2".to_string())); assert_eq!(ns_request.distance_type, Some("l2".to_string()));
assert!(ns_request.vector.single_vector.is_some()); assert!(ns_request.vector.single_vector.is_some());
@@ -5199,16 +5195,17 @@ mod tests {
let any_query = AnyQuery::Query(q); let any_query = AnyQuery::Query(q);
let ns_request = table.convert_to_namespace_query(&any_query).unwrap(); let ns_request = table.convert_to_namespace_query(&any_query).unwrap();
let column_names = ns_request
.columns
.as_ref()
.and_then(|cols| cols.column_names.clone());
// Plain queries should pass an empty vector // Plain queries should pass an empty vector
assert_eq!(ns_request.k, 20); assert_eq!(ns_request.k, 20);
assert_eq!(ns_request.offset, Some(5)); assert_eq!(ns_request.offset, Some(5));
assert_eq!(ns_request.filter, Some("id > 5".to_string())); assert_eq!(ns_request.filter, Some("id > 5".to_string()));
assert_eq!(column_names, Some(vec!["id".to_string()])); assert_eq!(
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.clone()),
Some(vec!["id".to_string()])
);
assert_eq!(ns_request.with_row_id, Some(true)); assert_eq!(ns_request.with_row_id, Some(true));
assert_eq!(ns_request.bypass_vector_index, Some(true)); assert_eq!(ns_request.bypass_vector_index, Some(true));
assert!(ns_request.vector_column.is_none()); // No vector column for plain queries assert!(ns_request.vector_column.is_none()); // No vector column for plain queries

View File

@@ -5,19 +5,16 @@
use regex::Regex; use regex::Regex;
use std::env; use std::env;
use std::process::Stdio; use std::io::{BufRead, BufReader};
use tokio::io::{AsyncBufReadExt, BufReader}; use std::process::{Child, ChildStdout, Command, Stdio};
use tokio::process::{Child, ChildStdout, Command};
use tokio::sync::mpsc;
use crate::{connect, Connection}; use crate::{connect, Connection};
use anyhow::{anyhow, bail, Result}; use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir}; use tempfile::{tempdir, TempDir};
pub struct TestConnection { pub struct TestConnection {
pub uri: String, pub uri: String,
pub connection: Connection, pub connection: Connection,
pub is_remote: bool,
_temp_dir: Option<TempDir>, _temp_dir: Option<TempDir>,
_process: Option<TestProcess>, _process: Option<TestProcess>,
} }
@@ -40,56 +37,6 @@ pub async fn new_test_connection() -> Result<TestConnection> {
} }
} }
async fn spawn_stdout_reader(
mut stdout: BufReader<ChildStdout>,
port_sender: mpsc::Sender<anyhow::Result<String>>,
) -> tokio::task::JoinHandle<()> {
let print_stdout = env::var("PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT").is_ok();
tokio::spawn(async move {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
line.clear();
let result = stdout.read_line(&mut line).await;
if let Err(err) = result {
port_sender
.send(Err(anyhow!(
"error while reading from process output: {}",
err
)))
.await
.unwrap();
return;
} else if result.unwrap() == 0 {
port_sender
.send(Err(anyhow!(
" hit EOF before reading port from process output."
)))
.await
.unwrap();
return;
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
port_sender.send(Ok(caps[1].to_string())).await.unwrap();
break;
}
}
loop {
line.clear();
match stdout.read_line(&mut line).await {
Err(_) => return,
Ok(0) => return,
Ok(_size) => {
if print_stdout {
print!("{}", line);
}
}
}
}
})
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> { async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?; let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string(); let data_path = temp_dir.path().to_str().unwrap().to_string();
@@ -110,25 +57,38 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
child: child_result.unwrap(), child: child_result.unwrap(),
}; };
let stdout = BufReader::new(process.child.stdout.take().unwrap()); let stdout = BufReader::new(process.child.stdout.take().unwrap());
let (port_sender, mut port_receiver) = mpsc::channel(5); let port = read_process_port(stdout)?;
let _reader = spawn_stdout_reader(stdout, port_sender).await;
let port = match port_receiver.recv().await {
None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
Some(Ok(port)) => port,
};
let uri = "db://test"; let uri = "db://test";
let host_override = format!("http://localhost:{}", port); let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?; let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection { Ok(TestConnection {
uri: uri.to_string(), uri: uri.to_string(),
connection, connection,
is_remote: true,
_temp_dir: Some(temp_dir), _temp_dir: Some(temp_dir),
_process: Some(process), _process: Some(process),
}) })
} }
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")] #[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> { async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
connect(uri) connect(uri)
@@ -154,7 +114,6 @@ async fn new_local_connection() -> Result<TestConnection> {
Ok(TestConnection { Ok(TestConnection {
uri: uri.to_string(), uri: uri.to_string(),
connection, connection,
is_remote: false,
_temp_dir: Some(temp_dir), _temp_dir: Some(temp_dir),
_process: None, _process: None,
}) })

View File

@@ -1,253 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
borrow::Cow,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
Error, Result,
};
#[derive(Debug)]
struct SlowMockEmbed {
name: String,
dim: usize,
delay_ms: u64,
call_count: Arc<AtomicUsize>,
}
impl SlowMockEmbed {
pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
Self {
name,
dim,
delay_ms,
call_count: Arc::new(AtomicUsize::new(0)),
}
}
pub fn get_call_count(&self) -> usize {
self.call_count.load(Ordering::SeqCst)
}
}
impl EmbeddingFunction for SlowMockEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as _,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// Simulate slow embedding computation
std::thread::sleep(Duration::from_millis(self.delay_ms));
self.call_count.fetch_add(1, Ordering::SeqCst);
let len = source.len();
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
let field = Field::new("item", inner.data_type().clone(), false);
let arr = FixedSizeListArray::new(
Arc::new(field),
self.dim as _,
inner,
Some(NullBuffer::new_valid(len)),
);
Ok(Arc::new(arr))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
fn create_test_batch() -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
let text = StringArray::from(vec!["hello", "world"]);
RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
message: format!("Failed to create test batch: {}", e),
})
}
#[test]
fn test_single_embedding_fast_path() {
// Single embedding should execute without spawning threads
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
let embedding_def = EmbeddingDefinition::new("text", "test", Some("embedding"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![(embedding_def, embed.clone() as Arc<dyn EmbeddingFunction>)];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
assert!(result.column_by_name("embedding").is_some());
assert_eq!(embed.get_call_count(), 1);
}
#[test]
fn test_multiple_embeddings_parallel() {
// Multiple embeddings should execute in parallel
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 100));
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 100));
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 100));
let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![
(def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
(def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
(def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
// Verify all embedding columns are present
assert!(result.column_by_name("emb1").is_some());
assert!(result.column_by_name("emb2").is_some());
assert!(result.column_by_name("emb3").is_some());
// Verify all embeddings were computed
assert_eq!(embed1.get_call_count(), 1);
assert_eq!(embed2.get_call_count(), 1);
assert_eq!(embed3.get_call_count(), 1);
}
#[test]
fn test_embedding_column_order_preserved() {
// Verify that embedding columns are added in the same order as definitions
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 10));
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 10));
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 10));
let def1 = EmbeddingDefinition::new("text", "embed1", Some("first"));
let def2 = EmbeddingDefinition::new("text", "embed2", Some("second"));
let def3 = EmbeddingDefinition::new("text", "embed3", Some("third"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![
(def1, embed1 as Arc<dyn EmbeddingFunction>),
(def2, embed2 as Arc<dyn EmbeddingFunction>),
(def3, embed3 as Arc<dyn EmbeddingFunction>),
];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
let result_schema = result.schema();
// Original column is first
assert_eq!(result_schema.field(0).name(), "text");
// Embedding columns follow in order
assert_eq!(result_schema.field(1).name(), "first");
assert_eq!(result_schema.field(2).name(), "second");
assert_eq!(result_schema.field(3).name(), "third");
}
#[test]
fn test_embedding_error_propagation() {
// Test that errors from embedding computation are properly propagated
#[derive(Debug)]
struct FailingEmbed {
name: String,
}
impl EmbeddingFunction for FailingEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
2,
true,
)))
}
fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
Err(Error::Runtime {
message: "Intentional failure".to_string(),
})
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed = Arc::new(FailingEmbed {
name: "failing".to_string(),
});
let def = EmbeddingDefinition::new("text", "failing", Some("emb"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![(def, embed as Arc<dyn EmbeddingFunction>)];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap();
assert!(result.is_err());
let err_msg = format!("{}", result.err().unwrap());
assert!(err_msg.contains("Intentional failure"));
}
#[test]
fn test_maybe_embedded_with_no_embeddings() {
// Test that MaybeEmbedded::No variant works correctly
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
let table_def = lancedb::table::TableDefinition {
schema: schema.clone(),
column_definitions: vec![lancedb::table::ColumnDefinition {
kind: lancedb::table::ColumnKind::Physical,
}],
};
let mut maybe_embedded = MaybeEmbedded::try_new(reader, table_def, None).unwrap();
let result = maybe_embedded.next().unwrap().unwrap();
assert_eq!(result.num_columns(), 1);
assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref());
}