Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-15 00:02:59 +00:00

Compare commits (13 commits): python-v0. ... codex/upda
| SHA1 |
|---|
| b2e0fc89e0 |
| 1840aa7edc |
| 489c91c5d6 |
| f0c3fe5c6d |
| 2f6d525802 |
| 4494eb9e56 |
| d67a8743ba |
| 46fcbbc1e3 |
| ff53b76ac0 |
| 2adb10e6a8 |
| ac164c352b |
| 8bcac7e372 |
| e496184ab2 |
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.23.1-beta.1"
+current_version = "0.23.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
.github/workflows/rust.yml (vendored): 4 lines changed
@@ -167,13 +167,13 @@ jobs:
       - name: Build
         run: |
           $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
-          cargo build --profile ci --features remote --tests --locked --target ${{ matrix.target }}
+          cargo build --profile ci --features aws,remote --tests --locked --target ${{ matrix.target }}
       - name: Run tests
         # Can only run tests when target matches host
        if: ${{ matrix.target == 'x86_64-pc-windows-msvc' }}
         run: |
           $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
-          cargo test --profile ci --features remote --locked
+          cargo test --profile ci --features aws,remote --locked

   msrv:
     # Check the minimum supported Rust version
Cargo.lock (generated): 857 lines changed. File diff suppressed because it is too large.
Cargo.toml: 60 lines changed
@@ -15,39 +15,39 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"

 [workspace.dependencies]
-lance = { "version" = "=1.0.1", default-features = false }
-lance-core = "=1.0.1"
-lance-datagen = "=1.0.1"
-lance-file = "=1.0.1"
-lance-io = { "version" = "=1.0.1", default-features = false }
-lance-index = "=1.0.1"
-lance-linalg = "=1.0.1"
-lance-namespace = "=1.0.1"
-lance-namespace-impls = { "version" = "=1.0.1", default-features = false }
-lance-table = "=1.0.1"
-lance-testing = "=1.0.1"
-lance-datafusion = "=1.0.1"
-lance-encoding = "=1.0.1"
-lance-arrow = "=1.0.1"
+lance = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-core = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-datagen = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-file = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-io = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-index = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-linalg = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace-impls = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-table = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-testing = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-datafusion = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-encoding = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-arrow = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
 ahash = "0.8"
 # Note that this one does not include pyarrow
-arrow = { version = "56.2", optional = false }
-arrow-array = "56.2"
-arrow-data = "56.2"
-arrow-ipc = "56.2"
-arrow-ord = "56.2"
-arrow-schema = "56.2"
-arrow-select = "56.2"
-arrow-cast = "56.2"
+arrow = { version = "57.2", optional = false }
+arrow-array = "57.2"
+arrow-data = "57.2"
+arrow-ipc = "57.2"
+arrow-ord = "57.2"
+arrow-schema = "57.2"
+arrow-select = "57.2"
+arrow-cast = "57.2"
 async-trait = "0"
-datafusion = { version = "50.1", default-features = false }
-datafusion-catalog = "50.1"
-datafusion-common = { version = "50.1", default-features = false }
-datafusion-execution = "50.1"
-datafusion-expr = "50.1"
-datafusion-physical-plan = "50.1"
+datafusion = { version = "51.0", default-features = false }
+datafusion-catalog = "51.0"
+datafusion-common = { version = "51.0", default-features = false }
+datafusion-execution = "51.0"
+datafusion-expr = "51.0"
+datafusion-physical-plan = "51.0"
 env_logger = "0.11"
-half = { "version" = "2.6.0", default-features = false, features = [
+half = { "version" = "2.7.1", default-features = false, features = [
     "num-traits",
 ] }
 futures = "0"

@@ -59,7 +59,7 @@ rand = "0.9"
 snafu = "0.8"
 url = "2"
 num-traits = "0.2"
-regex = "1.10"
+regex = "1.12"
 lazy_static = "1"
 semver = "1.0.25"
 chrono = "0.4"
@@ -16,7 +16,7 @@ check_command_exists() {
 }

 if [[ ! -e ./lancedb ]]; then
-  if [[ -v SOPHON_READ_TOKEN ]]; then
+  if [[ x${SOPHON_READ_TOKEN} != "x" ]]; then
     INPUT="lancedb-linux-x64"
     gh release \
       --repo lancedb/lancedb \
@@ -11,7 +11,7 @@ watch:
 theme:
   name: "material"
   logo: assets/logo.png
-  favicon: assets/logo.png
+  favicon: assets/favicon.ico
   palette:
     # Palette toggle for light mode
     - scheme: lancedb

@@ -32,8 +32,6 @@ theme:
     - content.tooltips
     - toc.follow
     - navigation.top
-    - navigation.tabs
-    - navigation.tabs.sticky
     - navigation.footer
     - navigation.tracking
     - navigation.instant

@@ -115,12 +113,13 @@ markdown_extensions:
       emoji_index: !!python/name:material.extensions.emoji.twemoji
       emoji_generator: !!python/name:material.extensions.emoji.to_svg
   - markdown.extensions.toc:
-      baselevel: 1
-      permalink: ""
+      toc_depth: 3
+      permalink: true
+      permalink_title: Anchor link to this section

 nav:
-  - API reference:
-      - Overview: index.md
+  - Documentation:
+      - SDK Reference: index.md
       - Python: python/python.md
       - Javascript/TypeScript: js/globals.md
       - Java: java/java.md
BIN docs/src/assets/favicon.ico (new file, 15 KiB; binary file not shown)
@@ -0,0 +1,111 @@ (new file)

# VoyageAI Embeddings: Multimodal

VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; the supported models are listed
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings).

Supported multimodal models:

- `voyage-multimodal-3` - 1024 dimensions (text + images)
- `voyage-multimodal-3.5` - Flexible dimensions (256, 512, 1024 default, 2048). Supports text, images, and video.

### Video Support (voyage-multimodal-3.5)

The `voyage-multimodal-3.5` model supports video input through:

- Video URLs (`.mp4`, `.webm`, `.mov`, `.avi`, `.mkv`, `.m4v`, `.gif`)
- Video file paths

Constraints: max 20MB video size.

Supported parameters (to be passed in the `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
| `output_dimension` | `int` | `None` | Output dimension for voyage-multimodal-3.5. Valid: 256, 512, 1024, 2048 |

Usage example:

```python
import base64
import os
from io import BytesIO

import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd

os.environ["VOYAGE_API_KEY"] = "YOUR_VOYAGE_API_KEY"

db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")


def image_to_base64(image_bytes: bytes):
    buffered = BytesIO(image_bytes)
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # image uri as the source
    image_bytes: str = func.SourceField()  # image bytes base64 encoded as the source
    vector: Vector(func.ndims()) = func.VectorField()  # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # Another vector column


if "images" in db.table_names():
    db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```

Now we can search using text against both the default vector column and the custom vector column:

```python
# text search against the default vector column
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

# text search against the custom vector column
frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```

Because we're using a multimodal embedding function, we can also search using images:

```python
from PIL import Image

# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```
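Video inputs follow the same source-field pattern. A minimal sketch, assuming `voyage-multimodal-3.5` is enabled for your API key; the video URL is hypothetical (any URL ending in a supported video extension is routed as video, and it must stay under the 20MB limit):

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

func35 = get_registry().get("voyageai").create(name="voyage-multimodal-3.5")


class Videos(LanceModel):
    label: str
    video_uri: str = func35.SourceField()  # URL ending in .mp4/.webm/... is embedded as video
    vector: Vector(func35.ndims()) = func35.VectorField()


db = lancedb.connect(".lancedb")
videos = db.create_table("videos", schema=Videos, mode="overwrite")
# hypothetical URL pointing at a small video clip
videos.add([{"label": "demo", "video_uri": "https://example.com/clip.mp4"}])
```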
@@ -1,8 +1,12 @@
-# API Reference
+# SDK Reference

-This page contains the API reference for the SDKs supported by the LanceDB team.
+This site contains the API reference for the client SDKs supported by [LanceDB](https://lancedb.com).

 - [Python](python/python.md)
 - [JavaScript/TypeScript](js/globals.md)
 - [Java](java/java.md)
 - [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
+
+!!! info "LanceDB Documentation"
+
+    If you're looking for the full documentation of LanceDB, visit [docs.lancedb.com](https://docs.lancedb.com).
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
 <dependency>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-core</artifactId>
-    <version>0.23.1-beta.1</version>
+    <version>0.23.1</version>
 </dependency>
 ```
@@ -85,17 +85,26 @@

 /* Header gradient (only header area) */
 .md-header {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
   box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08);
 }

+/* Improve brand title contrast on the lavender side */
+.md-header__title,
+.md-header__topic,
+.md-header__title .md-ellipsis,
+.md-header__topic .md-ellipsis {
+  color: #2b1b3a;
+  text-shadow: 0 1px 0 rgba(255, 255, 255, 0.25);
+}
+
 /* Same colors as header for tabs (that hold the text) */
 .md-tabs {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
 }

 /* Dark scheme variant */
 [data-md-color-scheme="slate"] .md-header,
 [data-md-color-scheme="slate"] .md-tabs {
-  background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
+  background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
 }
@@ -8,7 +8,7 @@
 <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.23.1-beta.1</version>
+    <version>0.23.1-final.0</version>
     <relativePath>../pom.xml</relativePath>
 </parent>
@@ -6,7 +6,7 @@

 <groupId>com.lancedb</groupId>
 <artifactId>lancedb-parent</artifactId>
-<version>0.23.1-beta.1</version>
+<version>0.23.1-final.0</version>
 <packaging>pom</packaging>
 <name>${project.artifactId}</name>
 <description>LanceDB Java SDK Parent POM</description>
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.23.1-beta.1"
+version = "0.23.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true

@@ -36,6 +36,6 @@ aws-lc-rs = "=1.13.0"
 napi-build = "2.1"

 [features]
-default = ["remote", "lancedb/default"]
+default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
 fp16kernels = ["lancedb/fp16kernels"]
 remote = ["lancedb/remote"]
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": [
     "win32"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
nodejs/package-lock.json (generated): 4 lines changed
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.23.1-beta.1",
+  "version": "0.23.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.23.1-beta.1",
+      "version": "0.23.1",
       "cpu": [
         "x64",
         "arm64"

@@ -11,7 +11,7 @@
   "ann"
 ],
 "private": false,
-"version": "0.23.1-beta.1",
+"version": "0.23.1",
 "main": "dist/index.js",
 "exports": {
   ".": "./dist/index.js",
@@ -14,15 +14,15 @@ name = "_lancedb"
 crate-type = ["cdylib"]

 [dependencies]
-arrow = { version = "56.2", features = ["pyarrow"] }
+arrow = { version = "57.2", features = ["pyarrow"] }
 async-trait = "0.1"
 lancedb = { path = "../rust/lancedb", default-features = false }
 lance-core.workspace = true
 lance-namespace.workspace = true
 lance-io.workspace = true
 env_logger.workspace = true
-pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
-pyo3-async-runtimes = { version = "0.25", features = [
+pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
+pyo3-async-runtimes = { version = "0.26", features = [
     "attributes",
     "tokio-runtime",
 ] }

@@ -32,12 +32,12 @@ snafu.workspace = true
 tokio = { version = "1.40", features = ["sync"] }

 [build-dependencies]
-pyo3-build-config = { version = "0.25", features = [
+pyo3-build-config = { version = "0.26", features = [
     "extension-module",
     "abi3-py39",
 ] }

 [features]
-default = ["remote", "lancedb/default"]
+default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
 fp16kernels = ["lancedb/fp16kernels"]
 remote = ["lancedb/remote"]
@@ -13,6 +13,7 @@ __version__ = importlib.metadata.version("lancedb")

 from ._lancedb import connect as lancedb_connect
 from .common import URI, sanitize_uri
+from urllib.parse import urlparse
 from .db import AsyncConnection, DBConnection, LanceDBConnection
 from .io import StorageOptionsProvider
 from .remote import ClientConfig

@@ -28,6 +29,39 @@ from .namespace import (
 )


+def _check_s3_bucket_with_dots(
+    uri: str, storage_options: Optional[Dict[str, str]]
+) -> None:
+    """
+    Check if an S3 URI has a bucket name containing dots and warn if no region
+    is specified. S3 buckets with dots cannot use virtual-hosted-style URLs,
+    which breaks automatic region detection.
+
+    See: https://github.com/lancedb/lancedb/issues/1898
+    """
+    if not isinstance(uri, str) or not uri.startswith("s3://"):
+        return
+
+    parsed = urlparse(uri)
+    bucket = parsed.netloc
+
+    if "." not in bucket:
+        return
+
+    # Check if region is provided in storage_options
+    region_keys = {"region", "aws_region"}
+    has_region = storage_options and any(k in storage_options for k in region_keys)
+
+    if not has_region:
+        raise ValueError(
+            f"S3 bucket name '{bucket}' contains dots, which prevents automatic "
+            f"region detection. Please specify the region explicitly via "
+            f"storage_options={{'region': '<your-region>'}} or "
+            f"storage_options={{'aws_region': '<your-region>'}}. "
+            f"See https://github.com/lancedb/lancedb/issues/1898 for details."
+        )
+
+
 def connect(
     uri: URI,
     *,

@@ -121,9 +155,11 @@ def connect(
         storage_options=storage_options,
         **kwargs,
     )
+    _check_s3_bucket_with_dots(str(uri), storage_options)
+
     if kwargs:
         raise ValueError(f"Unknown keyword arguments: {kwargs}")

     return LanceDBConnection(
         uri,
         read_consistency_interval=read_consistency_interval,

@@ -211,6 +247,8 @@ async def connect_async(
     if isinstance(client_config, dict):
         client_config = ClientConfig(**client_config)

+    _check_s3_bucket_with_dots(str(uri), storage_options)
+
     return AsyncConnection(
         await lancedb_connect(
             sanitize_uri(uri),
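A quick illustration of the new check, mirroring the tests added later in this change (the bucket names are the hypothetical ones used there):

```python
import lancedb

# Raises ValueError ("... contains dots ..."): the dotted bucket name breaks
# virtual-hosted-style URLs, so the region cannot be auto-detected.
# lancedb.connect("s3://my.bucket.name/path")

# Supplying the region explicitly (either "region" or "aws_region") passes the check.
db = lancedb.connect(
    "s3://my.bucket.name/path",
    storage_options={"region": "us-east-1"},
)
```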
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 import base64
 import os
-from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
+from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional

 from pathlib import Path
 from urllib.parse import urlparse

@@ -45,11 +45,29 @@ def is_valid_url(text):
     return False


+VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}
+
+
+def is_video_url(url: str) -> bool:
+    """Check if URL points to a video file based on extension."""
+    parsed = urlparse(url)
+    path = parsed.path.lower()
+    return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)
+
+
+def is_video_path(path: Path) -> bool:
+    """Check if file path is a video file based on extension."""
+    return path.suffix.lower() in VIDEO_EXTENSIONS
+
+
 def transform_input(input_data: Union[str, bytes, Path]):
     PIL = attempt_import_or_raise("PIL", "pillow")
     if isinstance(input_data, str):
         if is_valid_url(input_data):
-            content = {"type": "image_url", "image_url": input_data}
+            if is_video_url(input_data):
+                content = {"type": "video_url", "video_url": input_data}
+            else:
+                content = {"type": "image_url", "image_url": input_data}
         else:
             content = {"type": "text", "text": input_data}
     elif isinstance(input_data, PIL.Image.Image):

@@ -70,14 +88,24 @@ def transform_input(input_data: Union[str, bytes, Path]):
             "image_base64": "data:image/jpeg;base64," + img_str,
         }
     elif isinstance(input_data, Path):
-        img = PIL.Image.open(input_data)
-        buffered = BytesIO()
-        img.save(buffered, format="JPEG")
-        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-        content = {
-            "type": "image_base64",
-            "image_base64": "data:image/jpeg;base64," + img_str,
-        }
+        if is_video_path(input_data):
+            # Read video file and encode as base64
+            with open(input_data, "rb") as f:
+                video_bytes = f.read()
+            video_str = base64.b64encode(video_bytes).decode("utf-8")
+            content = {
+                "type": "video_base64",
+                "video_base64": video_str,
+            }
+        else:
+            img = PIL.Image.open(input_data)
+            buffered = BytesIO()
+            img.save(buffered, format="JPEG")
+            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+            content = {
+                "type": "image_base64",
+                "image_base64": "data:image/jpeg;base64," + img_str,
+            }
     else:
         raise ValueError("Each input should be either str, bytes, Path or Image.")

@@ -91,6 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
     PIL = attempt_import_or_raise("PIL", "pillow")
     if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
         inputs = [inputs]
+    elif isinstance(inputs, list):
+        pass  # Already a list, use as-is
     elif isinstance(inputs, pa.Array):
         inputs = inputs.to_pylist()
     elif isinstance(inputs, pa.ChunkedArray):

@@ -143,11 +173,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
     * voyage-3
     * voyage-3-lite
     * voyage-multimodal-3
+    * voyage-multimodal-3.5
     * voyage-finance-2
     * voyage-multilingual-2
     * voyage-law-2
     * voyage-code-2

+    output_dimension: int, optional
+        The output dimension for models that support flexible dimensions.
+        Currently only voyage-multimodal-3.5 supports this feature.
+        Valid options: 256, 512, 1024 (default), 2048.
+
     Examples
     --------

@@ -175,7 +210,10 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
     """

     name: str
+    output_dimension: Optional[int] = None
     client: ClassVar = None
+    _FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
+    _VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
     text_embedding_models: list = [
         "voyage-3.5",
         "voyage-3.5-lite",

@@ -186,7 +224,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
         "voyage-law-2",
         "voyage-code-2",
     ]
-    multimodal_embedding_models: list = ["voyage-multimodal-3"]
+    multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"]
     contextual_embedding_models: list = ["voyage-context-3"]

     def _is_multimodal_model(self, model_name: str):

@@ -198,6 +236,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
         return model_name in self.contextual_embedding_models or "context" in model_name

     def ndims(self):
+        # Handle flexible dimension models
+        if self.name in self._FLEXIBLE_DIM_MODELS:
+            if self.output_dimension is not None:
+                if self.output_dimension not in self._VALID_DIMENSIONS:
+                    raise ValueError(
+                        f"Invalid output_dimension {self.output_dimension} "
+                        f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
+                    )
+                return self.output_dimension
+            return 1024  # default dimension
+
         if self.name == "voyage-3-lite":
             return 512
         elif self.name == "voyage-code-2":

@@ -211,12 +260,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
             "voyage-finance-2",
             "voyage-multilingual-2",
             "voyage-law-2",
-            "voyage-multimodal-3",
         ]:
             return 1024
         else:
             raise ValueError(f"Model {self.name} not supported")

+    def _get_multimodal_kwargs(self, **kwargs):
+        """Get kwargs for multimodal embed call, including output_dimension if set."""
+        if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
+            kwargs["output_dimension"] = self.output_dimension
+        return kwargs
+
     def compute_query_embeddings(
         self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
     ) -> List[np.ndarray]:

@@ -234,6 +288,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
         """
         client = VoyageAIEmbeddingFunction._get_client()
         if self._is_multimodal_model(self.name):
+            kwargs = self._get_multimodal_kwargs(**kwargs)
             result = client.multimodal_embed(
                 inputs=[[query]], model=self.name, input_type="query", **kwargs
             )

@@ -275,6 +330,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
             )
         if has_images:
             # Use non-batched API for images
+            kwargs = self._get_multimodal_kwargs(**kwargs)
             result = client.multimodal_embed(
                 inputs=sanitized, model=self.name, input_type="document", **kwargs
             )

@@ -357,6 +413,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
             callable: A function that takes a batch of texts and returns embeddings.
         """
         if self._is_multimodal_model(self.name):
+            multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)

             def embed_batch(batch: List[str]) -> List[np.array]:
                 batch_inputs = sanitize_multimodal_input(batch)

@@ -364,7 +421,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
                     inputs=batch_inputs,
                     model=self.name,
                     input_type=input_type,
-                    **kwargs,
+                    **multimodal_kwargs,
                 )
                 return result.embeddings
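A small sketch of how the new helpers route string and path inputs; the example values are hypothetical, and it assumes `is_video_url` and `is_video_path` are imported from the module above:

```python
from pathlib import Path

print(is_video_url("https://example.com/clip.mp4"))   # True -> sent as video_url
print(is_video_url("https://example.com/photo.jpg"))  # False -> sent as image_url
print(is_video_path(Path("holiday.MOV")))             # True; suffix check is case-insensitive
print(is_video_path(Path("photo.jpg")))               # False -> opened with PIL, sent as image_base64
```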
@@ -384,6 +384,7 @@ class RemoteDBConnection(DBConnection):
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
         mode: Optional[str] = None,
+        exist_ok: bool = False,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         namespace: Optional[List[str]] = None,

@@ -412,6 +413,12 @@ class RemoteDBConnection(DBConnection):
         - pyarrow.Schema

         - [LanceModel][lancedb.pydantic.LanceModel]
+        mode: str, default "create"
+            The mode to use when creating the table.
+            Can be either "create", "overwrite", or "exist_ok".
+        exist_ok: bool, default False
+            If exist_ok is True, and mode is None or "create", mode will be changed
+            to "exist_ok".
         on_bad_vectors: str, default "error"
             What to do if any of the vectors are not the same size or contains NaNs.
             One of "error", "drop", "fill".

@@ -483,6 +490,11 @@ class RemoteDBConnection(DBConnection):
         LanceTable(table4)

         """
+        if exist_ok:
+            if mode == "create":
+                mode = "exist_ok"
+            elif not mode:
+                mode = "exist_ok"
         if namespace is None:
             namespace = []
         validate_table_name(name)
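In practice the new flag behaves like this (a sketch; `db` stands for an open RemoteDBConnection, and the values mirror the mock-server tests added later in this change):

```python
# mode defaults to None, so exist_ok=True maps it to "exist_ok":
# the request goes out as ?mode=exist_ok and an existing table is reused.
table = db.create_table("test", [{"id": 1}], exist_ok=True)

# An explicit mode="overwrite" takes precedence over exist_ok.
table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
```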
@@ -18,7 +18,17 @@ from lancedb._lancedb import (
     UpdateResult,
 )
 from lancedb.embeddings.base import EmbeddingFunctionConfig
-from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
+from lancedb.index import (
+    FTS,
+    BTree,
+    Bitmap,
+    HnswSq,
+    IvfFlat,
+    IvfPq,
+    IvfRq,
+    IvfSq,
+    LabelList,
+)
 from lancedb.remote.db import LOOP
 import pyarrow as pa

@@ -265,6 +275,12 @@ class RemoteTable(Table):
                 num_sub_vectors=num_sub_vectors,
                 num_bits=num_bits,
             )
+        elif index_type == "IVF_RQ":
+            config = IvfRq(
+                distance_type=metric,
+                num_partitions=num_partitions,
+                num_bits=num_bits,
+            )
         elif index_type == "IVF_SQ":
             config = IvfSq(distance_type=metric, num_partitions=num_partitions)
         elif index_type == "IVF_HNSW_PQ":

@@ -279,7 +295,8 @@ class RemoteTable(Table):
         else:
             raise ValueError(
                 f"Unknown vector index type: {index_type}. Valid options are"
-                " 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
+                " 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
+                " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
             )

         LOOP.run(
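A hedged sketch of requesting the new index type through `create_index` on a remote table; the parameter values are illustrative, and only `index_type`, `metric` (mapped to `distance_type`), `num_partitions`, and `num_bits` are taken from the branch above:

```python
# Selecting the new IVF_RQ index type; the RemoteTable maps these arguments
# into an IvfRq config as shown in the diff.
table.create_index(
    metric="cosine",
    index_type="IVF_RQ",
    num_partitions=256,
    num_bits=8,
)
```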
@@ -613,6 +613,133 @@ def test_voyageai_multimodal_embedding_text_function():
     assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+def test_voyageai_multimodal_35_embedding_function():
+    """Test voyage-multimodal-3.5 model with text input."""
+    voyageai = (
+        get_registry()
+        .get("voyageai")
+        .create(name="voyage-multimodal-3.5", max_retries=0)
+    )
+
+    class TextModel(LanceModel):
+        text: str = voyageai.SourceField()
+        vector: Vector(voyageai.ndims()) = voyageai.VectorField()
+
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+    db = lancedb.connect("~/lancedb")
+    tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
+    assert voyageai.ndims() == 1024
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+def test_voyageai_multimodal_35_flexible_dimensions():
+    """Test voyage-multimodal-3.5 model with custom output dimension."""
+    voyageai = (
+        get_registry()
+        .get("voyageai")
+        .create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
+    )
+
+    class TextModel(LanceModel):
+        text: str = voyageai.SourceField()
+        vector: Vector(voyageai.ndims()) = voyageai.VectorField()
+
+    assert voyageai.ndims() == 512
+
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+    db = lancedb.connect("~/lancedb")
+    tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == 512
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+def test_voyageai_multimodal_35_image_embedding():
+    """Test voyage-multimodal-3.5 model with image input."""
+    voyageai = (
+        get_registry()
+        .get("voyageai")
+        .create(name="voyage-multimodal-3.5", max_retries=0)
+    )
+
+    class Images(LanceModel):
+        label: str
+        image_uri: str = voyageai.SourceField()
+        vector: Vector(voyageai.ndims()) = voyageai.VectorField()
+
+    db = lancedb.connect("~/lancedb")
+    table = db.create_table(
+        "test_multimodal_35_images", schema=Images, mode="overwrite"
+    )
+    labels = ["cat", "dog"]
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+    ]
+    table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
+    assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
+    assert voyageai.ndims() == 1024
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
+def test_voyageai_multimodal_35_all_dimensions(dimension):
+    """Test voyage-multimodal-3.5 model with all valid output dimensions."""
+    voyageai = (
+        get_registry()
+        .get("voyageai")
+        .create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
+    )
+
+    assert voyageai.ndims() == dimension
+
+    class TextModel(LanceModel):
+        text: str = voyageai.SourceField()
+        vector: Vector(voyageai.ndims()) = voyageai.VectorField()
+
+    df = pd.DataFrame({"text": ["hello world"]})
+    db = lancedb.connect("~/lancedb")
+    tbl = db.create_table(
+        f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
+    )
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == dimension
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+def test_voyageai_multimodal_35_invalid_dimension():
+    """Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
+    with pytest.raises(ValueError, match="Invalid output_dimension"):
+        voyageai = (
+            get_registry()
+            .get("voyageai")
+            .create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
+        )
+        # ndims() is where the validation happens
+        voyageai.ndims()
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     importlib.util.find_spec("colpali_engine") is None,
@@ -168,6 +168,42 @@ def test_table_len_sync():
     assert len(table) == 1


+def test_create_table_exist_ok():
+    def handler(request):
+        if request.path == "/v1/table/test/create/?mode=exist_ok":
+            request.send_response(200)
+            request.send_header("Content-Type", "application/json")
+            request.end_headers()
+            request.wfile.write(b"{}")
+        else:
+            request.send_response(404)
+            request.end_headers()
+
+    with mock_lancedb_connection(handler) as db:
+        table = db.create_table("test", [{"id": 1}], exist_ok=True)
+        assert table is not None
+
+    with mock_lancedb_connection(handler) as db:
+        table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True)
+        assert table is not None
+
+
+def test_create_table_exist_ok_with_mode_overwrite():
+    def handler(request):
+        if request.path == "/v1/table/test/create/?mode=overwrite":
+            request.send_response(200)
+            request.send_header("Content-Type", "application/json")
+            request.end_headers()
+            request.wfile.write(b"{}")
+        else:
+            request.send_response(404)
+            request.end_headers()
+
+    with mock_lancedb_connection(handler) as db:
+        table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
+        assert table is not None
+
+
 @pytest.mark.asyncio
 async def test_http_error():
     request_id_holder = {"request_id": None}
python/python/tests/test_s3_bucket_dots.py (new file): 68 lines
@@ -0,0 +1,68 @@ (new file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

"""
Tests for S3 bucket names containing dots.

Related issue: https://github.com/lancedb/lancedb/issues/1898

These tests validate the early error checking for S3 bucket names with dots.
No actual S3 connection is made - validation happens before connection.
"""

import pytest
import lancedb

# Test URIs
BUCKET_WITH_DOTS = "s3://my.bucket.name/path"
BUCKET_WITH_DOTS_AND_REGION = ("s3://my.bucket.name", {"region": "us-east-1"})
BUCKET_WITH_DOTS_AND_AWS_REGION = ("s3://my.bucket.name", {"aws_region": "us-east-1"})
BUCKET_WITHOUT_DOTS = "s3://my-bucket/path"


class TestS3BucketWithDotsSync:
    """Tests for connect()."""

    def test_bucket_with_dots_requires_region(self):
        with pytest.raises(ValueError, match="contains dots"):
            lancedb.connect(BUCKET_WITH_DOTS)

    def test_bucket_with_dots_and_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_REGION
        db = lancedb.connect(uri, storage_options=opts)
        assert db is not None

    def test_bucket_with_dots_and_aws_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
        db = lancedb.connect(uri, storage_options=opts)
        assert db is not None

    def test_bucket_without_dots_passes(self):
        db = lancedb.connect(BUCKET_WITHOUT_DOTS)
        assert db is not None


class TestS3BucketWithDotsAsync:
    """Tests for connect_async()."""

    @pytest.mark.asyncio
    async def test_bucket_with_dots_requires_region(self):
        with pytest.raises(ValueError, match="contains dots"):
            await lancedb.connect_async(BUCKET_WITH_DOTS)

    @pytest.mark.asyncio
    async def test_bucket_with_dots_and_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_REGION
        db = await lancedb.connect_async(uri, storage_options=opts)
        assert db is not None

    @pytest.mark.asyncio
    async def test_bucket_with_dots_and_aws_region_passes(self):
        uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
        db = await lancedb.connect_async(uri, storage_options=opts)
        assert db is not None

    @pytest.mark.asyncio
    async def test_bucket_without_dots_passes(self):
        db = await lancedb.connect_async(BUCKET_WITHOUT_DOTS)
        assert db is not None
@@ -10,8 +10,7 @@ use arrow::{
 use futures::stream::StreamExt;
 use lancedb::arrow::SendableRecordBatchStream;
 use pyo3::{
-    exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
-    Python,
+    exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;

@@ -36,8 +35,11 @@ impl RecordBatchStream {
 #[pymethods]
 impl RecordBatchStream {
     #[getter]
-    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
-        (*self.schema).clone().into_pyarrow(py)
+    pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> {
+        (*self.schema)
+            .clone()
+            .into_pyarrow(py)
+            .map(|obj| obj.unbind())
     }

     pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {

@@ -53,7 +55,12 @@ impl RecordBatchStream {
                 .next()
                 .await
                 .ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
-            Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = inner_next.infer_error()?.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }
 }
@@ -12,7 +12,7 @@ use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
     pyclass, pyfunction, pymethods,
     types::{PyDict, PyDictMethods},
-    Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
+    Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;

@@ -114,7 +114,7 @@ impl Connection {
         data: Bound<'_, PyAny>,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<PyObject>,
+        storage_options_provider: Option<Py<PyAny>>,
         location: Option<String>,
     ) -> PyResult<Bound<'a, PyAny>> {
         let inner = self_.get_inner()?.clone();

@@ -152,7 +152,7 @@ impl Connection {
         schema: Bound<'_, PyAny>,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<PyObject>,
+        storage_options_provider: Option<Py<PyAny>>,
         location: Option<String>,
     ) -> PyResult<Bound<'a, PyAny>> {
         let inner = self_.get_inner()?.clone();

@@ -187,7 +187,7 @@ impl Connection {
         name: String,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<PyObject>,
+        storage_options_provider: Option<Py<PyAny>>,
         index_cache_size: Option<u32>,
         location: Option<String>,
     ) -> PyResult<Bound<'_, PyAny>> {

@@ -304,8 +304,10 @@ impl Connection {
             },
             page_token,
             limit: limit.map(|l| l as i32),
+            ..Default::default()
         };
         let response = inner.list_namespaces(request).await.infer_error()?;
+        #[allow(deprecated)]
         Python::with_gil(|py| -> PyResult<Py<PyDict>> {
             let dict = PyDict::new(py);
             dict.set_item("namespaces", response.namespaces)?;

@@ -325,11 +327,11 @@ impl Connection {
         let inner = self_.get_inner()?.clone();
         let py = self_.py();
         future_into_py(py, async move {
-            use lance_namespace::models::{create_namespace_request, CreateNamespaceRequest};
+            use lance_namespace::models::CreateNamespaceRequest;
             let mode_enum = mode.and_then(|m| match m.to_lowercase().as_str() {
-                "create" => Some(create_namespace_request::Mode::Create),
-                "exist_ok" => Some(create_namespace_request::Mode::ExistOk),
-                "overwrite" => Some(create_namespace_request::Mode::Overwrite),
+                "create" => Some("Create".to_string()),
+                "exist_ok" => Some("ExistOk".to_string()),
+                "overwrite" => Some("Overwrite".to_string()),
                 _ => None,
             });
             let request = CreateNamespaceRequest {

@@ -340,8 +342,10 @@ impl Connection {
             },
             mode: mode_enum,
             properties,
+            ..Default::default()
         };
         let response = inner.create_namespace(request).await.infer_error()?;
+        #[allow(deprecated)]
         Python::with_gil(|py| -> PyResult<Py<PyDict>> {
             let dict = PyDict::new(py);
             dict.set_item("properties", response.properties)?;

@@ -360,15 +364,15 @@ impl Connection {
         let inner = self_.get_inner()?.clone();
         let py = self_.py();
         future_into_py(py, async move {
-            use lance_namespace::models::{drop_namespace_request, DropNamespaceRequest};
+            use lance_namespace::models::DropNamespaceRequest;
             let mode_enum = mode.and_then(|m| match m.to_uppercase().as_str() {
-                "SKIP" => Some(drop_namespace_request::Mode::Skip),
-                "FAIL" => Some(drop_namespace_request::Mode::Fail),
+                "SKIP" => Some("Skip".to_string()),
+                "FAIL" => Some("Fail".to_string()),
                 _ => None,
             });
             let behavior_enum = behavior.and_then(|b| match b.to_uppercase().as_str() {
-                "RESTRICT" => Some(drop_namespace_request::Behavior::Restrict),
-                "CASCADE" => Some(drop_namespace_request::Behavior::Cascade),
+                "RESTRICT" => Some("Restrict".to_string()),
+                "CASCADE" => Some("Cascade".to_string()),
                 _ => None,
             });
             let request = DropNamespaceRequest {

@@ -379,8 +383,10 @@ impl Connection {
             },
             mode: mode_enum,
             behavior: behavior_enum,
+            ..Default::default()
         };
         let response = inner.drop_namespace(request).await.infer_error()?;
+        #[allow(deprecated)]
|
||||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||||
let dict = PyDict::new(py);
|
let dict = PyDict::new(py);
|
||||||
dict.set_item("properties", response.properties)?;
|
dict.set_item("properties", response.properties)?;
|
||||||
@@ -405,8 +411,10 @@ impl Connection {
|
|||||||
} else {
|
} else {
|
||||||
Some(namespace)
|
Some(namespace)
|
||||||
},
|
},
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
let response = inner.describe_namespace(request).await.infer_error()?;
|
let response = inner.describe_namespace(request).await.infer_error()?;
|
||||||
|
#[allow(deprecated)]
|
||||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||||
let dict = PyDict::new(py);
|
let dict = PyDict::new(py);
|
||||||
dict.set_item("properties", response.properties)?;
|
dict.set_item("properties", response.properties)?;
|
||||||
@@ -434,8 +442,10 @@ impl Connection {
|
|||||||
},
|
},
|
||||||
page_token,
|
page_token,
|
||||||
limit: limit.map(|l| l as i32),
|
limit: limit.map(|l| l as i32),
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
let response = inner.list_tables(request).await.infer_error()?;
|
let response = inner.list_tables(request).await.infer_error()?;
|
||||||
|
#[allow(deprecated)]
|
||||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||||
let dict = PyDict::new(py);
|
let dict = PyDict::new(py);
|
||||||
dict.set_item("tables", response.tables)?;
|
dict.set_item("tables", response.tables)?;
|
||||||
|
|||||||
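Most hunks in this file make the same mechanical change: the generated `lance_namespace` request models gained fields in the new release, so struct literals switch to functional-update syntax. A generic sketch of why `..Default::default()` keeps call sites compiling as a struct grows; the `Request` type here is illustrative, not the real generated model:

    #[derive(Default)]
    struct Request {
        id: Option<Vec<String>>,
        page_token: Option<String>,
        limit: Option<i32>,
        // Imagine this field appears only in a newer release of the models.
        vend_credentials: Option<bool>,
    }

    fn build() -> Request {
        Request {
            id: Some(vec!["ns".to_string()]),
            limit: Some(10),
            // Fills page_token, vend_credentials, and any future fields.
            ..Default::default()
        }
    }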
@@ -40,31 +40,34 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
             request_id,
             source,
             status_code,
-        } => Python::with_gil(|py| {
-            let message = err.to_string();
-            let http_err_cls = py
-                .import(intern!(py, "lancedb.remote.errors"))?
-                .getattr(intern!(py, "HttpError"))?;
-            let err = http_err_cls.call1((
-                message,
-                request_id,
-                status_code.map(|s| s.as_u16()),
-            ))?;
-
-            if let Some(cause) = source.source() {
-                // The HTTP error already includes the first cause. But
-                // we can add the rest of the chain if there is any more.
-                let cause_err = http_from_rust_error(
-                    py,
-                    cause,
-                    request_id,
-                    status_code.map(|s| s.as_u16()),
-                )?;
-                err.setattr(intern!(py, "__cause__"), cause_err)?;
-            }
-
-            Err(PyErr::from_value(err))
-        }),
+        } => {
+            #[allow(deprecated)]
+            Python::with_gil(|py| {
+                let message = err.to_string();
+                let http_err_cls = py
+                    .import(intern!(py, "lancedb.remote.errors"))?
+                    .getattr(intern!(py, "HttpError"))?;
+                let err = http_err_cls.call1((
+                    message,
+                    request_id,
+                    status_code.map(|s| s.as_u16()),
+                ))?;
+
+                if let Some(cause) = source.source() {
+                    // The HTTP error already includes the first cause. But
+                    // we can add the rest of the chain if there is any more.
+                    let cause_err = http_from_rust_error(
+                        py,
+                        cause,
+                        request_id,
+                        status_code.map(|s| s.as_u16()),
+                    )?;
+                    err.setattr(intern!(py, "__cause__"), cause_err)?;
+                }
+
+                Err(PyErr::from_value(err))
+            })
+        }
         LanceError::Retry {
             request_id,
             request_failures,
@@ -75,33 +78,37 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
             max_read_failures,
             source,
             status_code,
-        } => Python::with_gil(|py| {
-            let cause_err = http_from_rust_error(
-                py,
-                source.as_ref(),
-                request_id,
-                status_code.map(|s| s.as_u16()),
-            )?;
-
-            let message = err.to_string();
-            let retry_error_cls = py
-                .import(intern!(py, "lancedb.remote.errors"))?
-                .getattr("RetryError")?;
-            let err = retry_error_cls.call1((
-                message,
-                request_id,
-                *request_failures,
-                *connect_failures,
-                *read_failures,
-                *max_request_failures,
-                *max_connect_failures,
-                *max_read_failures,
-                status_code.map(|s| s.as_u16()),
-            ))?;
-
-            err.setattr(intern!(py, "__cause__"), cause_err)?;
-            Err(PyErr::from_value(err))
-        }),
+        } =>
+        {
+            #[allow(deprecated)]
+            Python::with_gil(|py| {
+                let cause_err = http_from_rust_error(
+                    py,
+                    source.as_ref(),
+                    request_id,
+                    status_code.map(|s| s.as_u16()),
+                )?;
+
+                let message = err.to_string();
+                let retry_error_cls = py
+                    .import(intern!(py, "lancedb.remote.errors"))?
+                    .getattr("RetryError")?;
+                let err = retry_error_cls.call1((
+                    message,
+                    request_id,
+                    *request_failures,
+                    *connect_failures,
+                    *read_failures,
+                    *max_request_failures,
+                    *max_connect_failures,
+                    *max_read_failures,
+                    status_code.map(|s| s.as_u16()),
+                ))?;
+
+                err.setattr(intern!(py, "__cause__"), cause_err)?;
+                Err(PyErr::from_value(err))
+            })
+        }
         _ => self.runtime_error(),
     },
 }
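The wrapped block walks the Rust error chain with `source.source()` and mirrors the remaining links onto the Python exception's `__cause__`. A standalone sketch of the chain-walking half, using only `std::error::Error`:

    use std::error::Error;

    // Print every link in an error's cause chain, outermost first.
    fn print_chain(err: &dyn Error) {
        let mut current: Option<&dyn Error> = Some(err);
        while let Some(e) = current {
            eprintln!("caused by: {}", e);
            current = e.source();
        }
    }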
@@ -12,6 +12,7 @@ pub struct PyHeaderProvider {

 impl Clone for PyHeaderProvider {
     fn clone(&self) -> Self {
+        #[allow(deprecated)]
         Python::with_gil(|py| Self {
             provider: self.provider.clone_ref(py),
         })
@@ -25,6 +26,7 @@ impl PyHeaderProvider {

     /// Get headers from the Python provider (internal implementation)
     fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
+        #[allow(deprecated)]
        Python::with_gil(|py| {
             // Call the get_headers method
             let result = self.provider.call_method0(py, "get_headers");
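`Py<PyAny>` cannot be cloned freely: bumping the Python reference count is only sound while attached to the interpreter (pyo3 gates a plain `Clone` behind its `py-clone` feature), hence the manual `Clone` impls here that briefly take the GIL. A minimal sketch of the idiom; the `Holder` type is illustrative:

    use pyo3::prelude::*;

    struct Holder {
        obj: Py<PyAny>,
    }

    impl Clone for Holder {
        fn clone(&self) -> Self {
            // clone_ref increments the Python refcount, which requires the GIL.
            Python::with_gil(|py| Self {
                obj: self.obj.clone_ref(py),
            })
        }
    }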
@@ -19,7 +19,7 @@ use pyo3::{
     exceptions::PyRuntimeError,
     pyclass, pymethods,
     types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
-    Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
+    Bound, Py, PyAny, PyRef, PyRefMut, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;

@@ -281,7 +281,12 @@ impl PyPermutationReader {
         let reader = slf.reader.clone();
         future_into_py(slf.py(), async move {
             let schema = reader.output_schema(selection).await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -29,6 +29,7 @@ use pyo3::types::PyList;
 use pyo3::types::{PyDict, PyString};
 use pyo3::Bound;
 use pyo3::IntoPyObject;
+use pyo3::Py;
 use pyo3::PyAny;
 use pyo3::PyRef;
 use pyo3::PyResult;
@@ -453,7 +454,12 @@ impl Query {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -532,7 +538,12 @@ impl TakeQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -627,7 +638,12 @@ impl FTSQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -806,7 +822,12 @@ impl VectorQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -17,11 +17,12 @@ use pyo3::types::PyDict;
 /// Internal wrapper around a Python object implementing StorageOptionsProvider
 pub struct PyStorageOptionsProvider {
     /// The Python object implementing fetch_storage_options()
-    inner: PyObject,
+    inner: Py<PyAny>,
 }

 impl Clone for PyStorageOptionsProvider {
     fn clone(&self) -> Self {
+        #[allow(deprecated)]
         Python::with_gil(|py| Self {
             inner: self.inner.clone_ref(py),
         })
@@ -29,7 +30,8 @@ impl Clone for PyStorageOptionsProvider {
 }

 impl PyStorageOptionsProvider {
-    pub fn new(obj: PyObject) -> PyResult<Self> {
+    pub fn new(obj: Py<PyAny>) -> PyResult<Self> {
+        #[allow(deprecated)]
         Python::with_gil(|py| {
             // Verify the object has a fetch_storage_options method
             if !obj.bind(py).hasattr("fetch_storage_options")? {
@@ -37,7 +39,9 @@ impl PyStorageOptionsProvider {
                     "StorageOptionsProvider must implement fetch_storage_options() method",
                 ));
             }
-            Ok(Self { inner: obj })
+            Ok(Self {
+                inner: obj.clone_ref(py),
+            })
         })
     }
 }
@@ -60,6 +64,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
         let py_provider = self.py_provider.clone();

         tokio::task::spawn_blocking(move || {
+            #[allow(deprecated)]
            Python::with_gil(|py| {
                 // Call the Python fetch_storage_options method
                 let result = py_provider
@@ -119,6 +124,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
     }

     fn provider_id(&self) -> String {
+        #[allow(deprecated)]
         Python::with_gil(|py| {
             // Call provider_id() method on the Python object
             let obj = self.py_provider.inner.bind(py);
@@ -143,7 +149,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
 /// This is the main entry point for converting Python StorageOptionsProvider objects
 /// to Rust trait objects that can be used by the Lance ecosystem.
 pub fn py_object_to_storage_options_provider(
-    py_obj: PyObject,
+    py_obj: Py<PyAny>,
 ) -> PyResult<Arc<dyn StorageOptionsProvider>> {
     let py_provider = PyStorageOptionsProvider::new(py_obj)?;
     Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))
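Because `fetch_storage_options` is a synchronous Python call, invoking it directly from an async context would pin a tokio worker while waiting on the GIL; the wrapper therefore hops to the blocking pool first. A reduced sketch of the pattern, with the method name taken from the diff and the return type simplified to a `String` for illustration:

    use pyo3::prelude::*;

    // Run a GIL-bound Python call off the async executor threads.
    async fn call_provider(obj: Py<PyAny>) -> PyResult<String> {
        tokio::task::spawn_blocking(move || {
            Python::with_gil(|py| {
                let result = obj.call_method0(py, "fetch_storage_options")?;
                result.extract::<String>(py)
            })
        })
        .await
        .expect("blocking task panicked")
    }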
@@ -21,7 +21,7 @@ use pyo3::{
     exceptions::{PyKeyError, PyRuntimeError, PyValueError},
     pyclass, pymethods,
     types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
-    Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
+    Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;

@@ -287,7 +287,12 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.schema().await.infer_error()?;
-            Python::with_gil(|py| schema.to_pyarrow(py))
+            #[allow(deprecated)]
+            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+                let bound = schema.to_pyarrow(py)?;
+                Ok(bound.unbind())
+            })?;
+            Ok(py_obj)
         })
     }

@@ -437,6 +442,7 @@ impl Table {
         future_into_py(self_.py(), async move {
             let stats = inner.index_stats(&index_name).await.infer_error()?;
             if let Some(stats) = stats {
+                #[allow(deprecated)]
                 Python::with_gil(|py| {
                     let dict = PyDict::new(py);
                     dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
@@ -467,6 +473,7 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let stats = inner.stats().await.infer_error()?;
+            #[allow(deprecated)]
             Python::with_gil(|py| {
                 let dict = PyDict::new(py);
                 dict.set_item("total_bytes", stats.total_bytes)?;
@@ -516,6 +523,7 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let versions = inner.list_versions().await.infer_error()?;
+            #[allow(deprecated)]
             let versions_as_dict = Python::with_gil(|py| {
                 versions
                     .iter()
@@ -867,6 +875,7 @@ impl Tags {
             let tags = inner.tags().await.infer_error()?;
             let res = tags.list().await.infer_error()?;

+            #[allow(deprecated)]
             Python::with_gil(|py| {
                 let py_dict = PyDict::new(py);
                 for (key, contents) in res {
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.23.1-beta.1"
+version = "0.23.1"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -104,11 +104,16 @@ test-log = "0.2"


 [features]
-default = ["aws", "gcs", "azure", "dynamodb", "oss"]
+default = []
 aws = ["lance/aws", "lance-io/aws", "lance-namespace-impls/dir-aws"]
 oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"]
 gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"]
 azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"]
+huggingface = [
+    "lance/huggingface",
+    "lance-io/huggingface",
+    "lance-namespace-impls/dir-huggingface",
+]
 dynamodb = ["lance/dynamodb", "aws"]
 remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
 fp16kernels = ["lance-linalg/fp16kernels"]
@@ -148,3 +153,6 @@ name = "ivf_pq"
 [[example]]
 name = "hybrid_search"
 required-features = ["sentence-transformers"]
+
+[package.metadata.docs.rs]
+all-features = true
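With `default = []`, none of the object-store backends ship unless a consumer asks for them. A sketch of what a downstream Cargo.toml might look like after this change; the version number is illustrative:

    [dependencies]
    lancedb = { version = "0.23", features = ["aws", "remote"] }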
@@ -1325,25 +1325,27 @@ mod tests {

     #[tokio::test]
     async fn test_table_names() {
-        let tmp_dir = tempdir().unwrap();
+        let tc = new_test_connection().await.unwrap();
+        let db = tc.connection;
+        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
         let mut names = Vec::with_capacity(100);
         for _ in 0..100 {
-            let mut name = uuid::Uuid::new_v4().to_string();
+            let name = uuid::Uuid::new_v4().to_string();
             names.push(name.clone());
-            name.push_str(".lance");
-            create_dir_all(tmp_dir.path().join(&name)).unwrap();
+            db.create_empty_table(name, schema.clone())
+                .execute()
+                .await
+                .unwrap();
         }
         names.sort();
-
-        let uri = tmp_dir.path().to_str().unwrap();
-        let db = connect(uri).execute().await.unwrap();
-        let tables = db.table_names().execute().await.unwrap();
+        let tables = db.table_names().limit(100).execute().await.unwrap();

         assert_eq!(tables, names);

         let tables = db
             .table_names()
             .start_after(&names[30])
+            .limit(100)
             .execute()
             .await
             .unwrap();
@@ -9,7 +9,7 @@ use std::sync::Arc;
 use async_trait::async_trait;
 use lance_namespace::{
     models::{
-        CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
+        CreateNamespaceRequest, CreateNamespaceResponse, DeclareTableRequest,
         DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
         DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
         ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
@@ -137,6 +137,7 @@ impl Database for LanceNamespaceDatabase {
             id: Some(request.namespace),
             page_token: request.start_after,
             limit: request.limit.map(|l| l as i32),
+            ..Default::default()
         };

         let response = self.namespace.list_tables(ns_request).await?;
@@ -154,6 +155,7 @@ impl Database for LanceNamespaceDatabase {
         let describe_request = DescribeTableRequest {
             id: Some(table_id.clone()),
             version: None,
+            ..Default::default()
         };

         let describe_result = self.namespace.describe_table(describe_request).await;
@@ -171,6 +173,7 @@ impl Database for LanceNamespaceDatabase {
             // Drop the existing table - must succeed
             let drop_request = DropTableRequest {
                 id: Some(table_id.clone()),
+                ..Default::default()
             };
             self.namespace
                 .drop_table(drop_request)
@@ -202,22 +205,19 @@ impl Database for LanceNamespaceDatabase {
         let mut table_id = request.namespace.clone();
         table_id.push(request.name.clone());

-        let create_empty_request = CreateEmptyTableRequest {
+        let create_empty_request = DeclareTableRequest {
             id: Some(table_id.clone()),
             location: None,
-            properties: if self.storage_options.is_empty() {
-                None
-            } else {
-                Some(self.storage_options.clone())
-            },
+            vend_credentials: None,
+            ..Default::default()
         };

         let create_empty_response = self
             .namespace
-            .create_empty_table(create_empty_request)
+            .declare_table(create_empty_request)
             .await
             .map_err(|e| Error::Runtime {
-                message: format!("Failed to create empty table: {}", e),
+                message: format!("Failed to declare table: {}", e),
             })?;

         let location = create_empty_response
@@ -281,7 +281,10 @@ impl Database for LanceNamespaceDatabase {
         let mut table_id = namespace.to_vec();
         table_id.push(name.to_string());

-        let drop_request = DropTableRequest { id: Some(table_id) };
+        let drop_request = DropTableRequest {
+            id: Some(table_id),
+            ..Default::default()
+        };
         self.namespace
             .drop_table(drop_request)
             .await
@@ -438,6 +441,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -499,6 +503,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -563,6 +568,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -647,6 +653,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -703,6 +710,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -784,6 +792,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -818,6 +827,7 @@ mod tests {
                 id: Some(vec!["test_ns".into()]),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -120,8 +120,13 @@ impl MemoryRegistry {
 }

 /// A record batch reader that has embeddings applied to it
-/// This is a wrapper around another record batch reader that applies an embedding function
-/// when reading from the record batch
+///
+/// This is a wrapper around another record batch reader that applies embedding functions
+/// when reading from the record batch.
+///
+/// When multiple embedding functions are defined, they are computed in parallel using
+/// scoped threads to improve performance. For a single embedding function, computation
+/// is done inline without threading overhead.
 pub struct WithEmbeddings<R: RecordBatchReader> {
     inner: R,
     embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
@@ -235,6 +240,48 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
             column_definitions,
         })
     }
+
+    fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
+        if self.embeddings.len() == 1 {
+            let (fld, func) = &self.embeddings[0];
+            let src_column =
+                batch
+                    .column_by_name(&fld.source_column)
+                    .ok_or_else(|| Error::InvalidInput {
+                        message: format!("Source column '{}' not found", fld.source_column),
+                    })?;
+            return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
+        }
+
+        // Parallel path: multiple embeddings
+        std::thread::scope(|s| {
+            let handles: Vec<_> = self
+                .embeddings
+                .iter()
+                .map(|(fld, func)| {
+                    let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
+                        Error::InvalidInput {
+                            message: format!("Source column '{}' not found", fld.source_column),
+                        }
+                    })?;
+
+                    let handle =
+                        s.spawn(move || func.compute_source_embeddings(src_column.clone()));
+
+                    Ok(handle)
+                })
+                .collect::<Result<_>>()?;
+
+            handles
+                .into_iter()
+                .map(|h| {
+                    h.join().map_err(|e| Error::Runtime {
+                        message: format!("Thread panicked during embedding computation: {:?}", e),
+                    })?
+                })
+                .collect()
+        })
+    }
 }

 impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
@@ -262,19 +309,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
     fn next(&mut self) -> Option<Self::Item> {
         let batch = self.inner.next()?;
         match batch {
-            Ok(mut batch) => {
-                // todo: parallelize this
-                for (fld, func) in self.embeddings.iter() {
-                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
-                    let embedding = match func.compute_source_embeddings(src_column.clone()) {
-                        Ok(embedding) => embedding,
-                        Err(e) => {
-                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
-                                "Error computing embedding: {}",
-                                e
-                            ))))
-                        }
-                    };
+            Ok(batch) => {
+                let embeddings = match self.compute_embeddings_parallel(&batch) {
+                    Ok(emb) => emb,
+                    Err(e) => {
+                        return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
+                            "Error computing embedding: {}",
+                            e
+                        ))))
+                    }
+                };
+
+                let mut batch = batch;
+                for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
                     let dst_field_name = fld
                         .dest_column
                         .clone()
@@ -286,7 +333,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
                         embedding.nulls().is_some(),
                     );

-                    match batch.try_with_column(dst_field.clone(), embedding) {
+                    match batch.try_with_column(dst_field.clone(), embedding.clone()) {
                         Ok(b) => batch = b,
                         Err(e) => return Some(Err(e)),
                     };
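`std::thread::scope` is what lets `compute_embeddings_parallel` borrow `batch` from the enclosing frame with no `Arc` or `'static` bounds, and it joins every spawned thread before the scope returns. A minimal standalone sketch of the same fan-out/join shape:

    // Run one closure per input in parallel and collect results in order.
    fn parallel_map(inputs: &[i64]) -> Vec<i64> {
        std::thread::scope(|s| {
            let handles: Vec<_> = inputs
                .iter()
                .map(|x| s.spawn(move || x * 2)) // borrows `inputs`; no 'static needed
                .collect();
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        })
    }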
@@ -25,13 +25,14 @@
 //!
 //! ## Crate Features
 //!
-//! ### Experimental Features
-//!
-//! These features are not enabled by default. They are experimental or in-development features that
-//! are not yet ready to be released.
-//!
-//! - `remote` - Enable remote client to connect to LanceDB cloud. This is not yet fully implemented
-//! and should not be enabled.
+//! - `aws` - Enable AWS S3 object store support.
+//! - `dynamodb` - Enable DynamoDB manifest store support.
+//! - `azure` - Enable Azure Blob Storage object store support.
+//! - `gcs` - Enable Google Cloud Storage object store support.
+//! - `oss` - Enable Alibaba Cloud OSS object store support.
+//! - `remote` - Enable remote client to connect to LanceDB cloud.
+//! - `huggingface` - Enable HuggingFace Hub integration for loading datasets from the Hub.
+//! - `fp16kernels` - Enable FP16 kernels for faster vector search on CPU.
 //!
 //! ### Quick Start
 //!
@@ -1720,6 +1720,7 @@ mod tests {
                 id: Some(namespace.clone()),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -1746,6 +1747,7 @@ mod tests {
                 id: Some(namespace.clone()),
                 page_token: None,
                 limit: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to list tables");
@@ -1758,6 +1760,7 @@ mod tests {
                 id: Some(namespace.clone()),
                 page_token: None,
                 limit: None,
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -1799,6 +1802,7 @@ mod tests {
                 id: Some(namespace.clone()),
                 mode: None,
                 properties: None,
+                ..Default::default()
             })
             .await
             .expect("Failed to create namespace");
@@ -1825,6 +1829,7 @@ mod tests {
                 id: Some(namespace.clone()),
                 page_token: None,
                 limit: None,
+                ..Default::default()
             })
             .await
             .unwrap();
@@ -1088,6 +1088,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
                     body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
                 }
             }
+            Index::IvfRq(index) => {
+                body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
+                body[METRIC_TYPE_KEY] =
+                    serde_json::Value::String(index.distance_type.to_string().to_lowercase());
+                if let Some(num_partitions) = index.num_partitions {
+                    body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
+                }
+                if let Some(num_bits) = index.num_bits {
+                    body["num_bits"] = serde_json::Value::Number(num_bits.into());
+                }
+            }
             Index::BTree(_) => {
                 body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
             }
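The new `IvfRq` arm builds the remote request body by assigning into a `serde_json::Value`, the same style as the surrounding arms. A tiny sketch of that indexing idiom; the keys mirror the diff but the function is illustrative:

    use serde_json::{json, Value};

    fn build_index_body() -> Value {
        let mut body = json!({});
        // Assigning through an index auto-inserts the key on a JSON object.
        body["index_type"] = Value::String("IVF_RQ".to_string());
        body["num_partitions"] = Value::Number(256.into());
        body["num_bits"] = Value::Number(8.into());
        body
    }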
@@ -42,8 +42,8 @@ use lance_index::DatasetIndexExt;
 use lance_index::IndexType;
 use lance_io::object_store::LanceNamespaceStorageOptionsProvider;
 use lance_namespace::models::{
-    QueryTableRequest as NsQueryTableRequest, QueryTableRequestFullTextQuery,
-    QueryTableRequestVector, StringFtsQuery,
+    QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
+    QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
 };
 use lance_namespace::LanceNamespace;
 use lance_table::format::Manifest;
@@ -1424,7 +1424,9 @@ impl Table {
             })
             .collect::<Vec<_>>();

-        let unioned = Arc::new(UnionExec::new(projected_plans));
+        let unioned = UnionExec::try_new(projected_plans).map_err(|e| Error::Runtime {
+            message: format!("Failed to build union plan: {e}"),
+        })?;
         // We require 1 partition in the final output
         let repartitioned = RepartitionExec::try_new(
             unioned,
@@ -2346,9 +2348,12 @@ impl NativeTable {
         };

         // Convert select to columns list
-        let columns = match &vq.base.select {
+        let columns: Option<Box<QueryTableRequestColumns>> = match &vq.base.select {
             Select::All => None,
-            Select::Columns(cols) => Some(cols.clone()),
+            Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
+                column_names: Some(cols.clone()),
+                column_aliases: None,
+            })),
             Select::Dynamic(_) => {
                 return Err(Error::NotSupported {
                     message:
@@ -2402,6 +2407,7 @@ impl NativeTable {
             bypass_vector_index: Some(!vq.use_index),
             full_text_query,
             version: None,
+            ..Default::default()
         })
     }
     AnyQuery::Query(q) => {
@@ -2419,9 +2425,12 @@ impl NativeTable {
             .map(|f| self.filter_to_sql(f))
             .transpose()?;

-        let columns = match &q.select {
+        let columns: Option<Box<QueryTableRequestColumns>> = match &q.select {
             Select::All => None,
-            Select::Columns(cols) => Some(cols.clone()),
+            Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
+                column_names: Some(cols.clone()),
+                column_aliases: None,
+            })),
             Select::Dynamic(_) => {
                 return Err(Error::NotSupported {
                     message: "Dynamic columns are not supported for server-side query"
@@ -2472,6 +2481,7 @@ impl NativeTable {
             fast_search: None,
             lower_bound: None,
             upper_bound: None,
+            ..Default::default()
         })
     }
 }
@@ -5143,10 +5153,15 @@ mod tests {
         let any_query = AnyQuery::VectorQuery(vq);
         let ns_request = table.convert_to_namespace_query(&any_query).unwrap();

+        let column_names = ns_request
+            .columns
+            .as_ref()
+            .and_then(|cols| cols.column_names.clone());
+
         assert_eq!(ns_request.k, 10);
         assert_eq!(ns_request.offset, Some(5));
         assert_eq!(ns_request.filter, Some("id > 0".to_string()));
-        assert_eq!(ns_request.columns, Some(vec!["id".to_string()]));
+        assert_eq!(column_names, Some(vec!["id".to_string()]));
         assert_eq!(ns_request.vector_column, Some("vector".to_string()));
         assert_eq!(ns_request.distance_type, Some("l2".to_string()));
         assert!(ns_request.vector.single_vector.is_some());
@@ -5183,11 +5198,16 @@ mod tests {
         let any_query = AnyQuery::Query(q);
         let ns_request = table.convert_to_namespace_query(&any_query).unwrap();

+        let column_names = ns_request
+            .columns
+            .as_ref()
+            .and_then(|cols| cols.column_names.clone());
+
         // Plain queries should pass an empty vector
         assert_eq!(ns_request.k, 20);
         assert_eq!(ns_request.offset, Some(5));
         assert_eq!(ns_request.filter, Some("id > 5".to_string()));
-        assert_eq!(ns_request.columns, Some(vec!["id".to_string()]));
+        assert_eq!(column_names, Some(vec!["id".to_string()]));
         assert_eq!(ns_request.with_row_id, Some(true));
         assert_eq!(ns_request.bypass_vector_index, Some(true));
         assert!(ns_request.vector_column.is_none()); // No vector column for plain queries
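The column selection moves from a bare `Vec<String>` to the richer `QueryTableRequestColumns` struct, which separates names from aliases. Going only by the fields visible in these hunks (the generated struct may carry more), construction looks like this sketch:

    use lance_namespace::models::QueryTableRequestColumns;

    fn select(names: Vec<String>) -> Option<Box<QueryTableRequestColumns>> {
        Some(Box::new(QueryTableRequestColumns {
            column_names: Some(names),
            column_aliases: None,
        }))
    }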
@@ -100,7 +100,8 @@ impl DatasetRef {
         let should_checkout = match &target_ref {
             refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
             refs::Ref::Version(_, None) => true, // No specific version, always checkout
+            refs::Ref::VersionNumber(target_ver) => version != target_ver,
             refs::Ref::Tag(_) => true, // Always checkout for tags
         };

         if should_checkout {
@@ -5,16 +5,19 @@

 use regex::Regex;
 use std::env;
-use std::io::{BufRead, BufReader};
-use std::process::{Child, ChildStdout, Command, Stdio};
+use std::process::Stdio;
+use tokio::io::{AsyncBufReadExt, BufReader};
+use tokio::process::{Child, ChildStdout, Command};
+use tokio::sync::mpsc;

 use crate::{connect, Connection};
-use anyhow::{bail, Result};
+use anyhow::{anyhow, bail, Result};
 use tempfile::{tempdir, TempDir};

 pub struct TestConnection {
     pub uri: String,
     pub connection: Connection,
+    pub is_remote: bool,
     _temp_dir: Option<TempDir>,
     _process: Option<TestProcess>,
 }
@@ -37,6 +40,56 @@ pub async fn new_test_connection() -> Result<TestConnection> {
     }
 }

+async fn spawn_stdout_reader(
+    mut stdout: BufReader<ChildStdout>,
+    port_sender: mpsc::Sender<anyhow::Result<String>>,
+) -> tokio::task::JoinHandle<()> {
+    let print_stdout = env::var("PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT").is_ok();
+    tokio::spawn(async move {
+        let mut line = String::new();
+        let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
+        loop {
+            line.clear();
+            let result = stdout.read_line(&mut line).await;
+            if let Err(err) = result {
+                port_sender
+                    .send(Err(anyhow!(
+                        "error while reading from process output: {}",
+                        err
+                    )))
+                    .await
+                    .unwrap();
+                return;
+            } else if result.unwrap() == 0 {
+                port_sender
+                    .send(Err(anyhow!(
+                        " hit EOF before reading port from process output."
+                    )))
+                    .await
+                    .unwrap();
+                return;
+            }
+            if re.is_match(&line) {
+                let caps = re.captures(&line).unwrap();
+                port_sender.send(Ok(caps[1].to_string())).await.unwrap();
+                break;
+            }
+        }
+        loop {
+            line.clear();
+            match stdout.read_line(&mut line).await {
+                Err(_) => return,
+                Ok(0) => return,
+                Ok(_size) => {
+                    if print_stdout {
+                        print!("{}", line);
+                    }
+                }
+            }
+        }
+    })
+}
+
 async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
     let temp_dir = tempdir()?;
     let data_path = temp_dir.path().to_str().unwrap().to_string();
@@ -57,38 +110,25 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
         child: child_result.unwrap(),
     };
     let stdout = BufReader::new(process.child.stdout.take().unwrap());
-    let port = read_process_port(stdout)?;
+    let (port_sender, mut port_receiver) = mpsc::channel(5);
+    let _reader = spawn_stdout_reader(stdout, port_sender).await;
+    let port = match port_receiver.recv().await {
+        None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
+        Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
+        Some(Ok(port)) => port,
+    };
     let uri = "db://test";
     let host_override = format!("http://localhost:{}", port);
     let connection = create_new_connection(uri, &host_override).await?;
     Ok(TestConnection {
         uri: uri.to_string(),
         connection,
+        is_remote: true,
         _temp_dir: Some(temp_dir),
         _process: Some(process),
     })
 }

-fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
-    let mut line = String::new();
-    let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
-    loop {
-        let result = stdout.read_line(&mut line);
-        if let Err(err) = result {
-            bail!(format!(
-                "read_process_port: error while reading from process output: {}",
-                err
-            ));
-        } else if result.unwrap() == 0 {
-            bail!("read_process_port: hit EOF before reading port from process output.");
-        }
-        if re.is_match(&line) {
-            let caps = re.captures(&line).unwrap();
-            return Ok(caps[1].to_string());
-        }
-    }
-}
-
 #[cfg(feature = "remote")]
 async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
     connect(uri)
@@ -114,6 +154,7 @@ async fn new_local_connection() -> Result<TestConnection> {
     Ok(TestConnection {
         uri: uri.to_string(),
         connection,
+        is_remote: false,
         _temp_dir: Some(temp_dir),
         _process: None,
     })
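The rewrite replaces the blocking `read_process_port` loop with a spawned task that reports the port over an mpsc channel, so the caller simply awaits the first message and the task keeps draining stdout afterwards. A stripped-down sketch of that handshake, assuming tokio's `macros` and `sync` features:

    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<String>(1);
        // Reader task: the real code scans child-process stdout for the
        // "listening on" line; here we just send a canned port number.
        tokio::spawn(async move {
            tx.send("8080".to_string()).await.unwrap();
        });
        let port = rx.recv().await.expect("reader task closed early");
        println!("connect to http://localhost:{port}");
    }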
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
borrow::Cow,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use arrow::buffer::NullBuffer;
|
||||||
|
use arrow_array::{
|
||||||
|
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
|
||||||
|
};
|
||||||
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
|
use lancedb::{
|
||||||
|
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
|
||||||
|
Error, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SlowMockEmbed {
|
||||||
|
name: String,
|
||||||
|
dim: usize,
|
||||||
|
delay_ms: u64,
|
||||||
|
call_count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SlowMockEmbed {
|
||||||
|
pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
name,
|
||||||
|
dim,
|
||||||
|
delay_ms,
|
||||||
|
call_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_call_count(&self) -> usize {
|
||||||
|
self.call_count.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingFunction for SlowMockEmbed {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn source_type(&self) -> Result<Cow<'_, DataType>> {
|
||||||
|
Ok(Cow::Owned(DataType::Utf8))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
|
||||||
|
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||||
|
DataType::Float32,
|
||||||
|
self.dim as _,
|
||||||
|
true,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||||
|
// Simulate slow embedding computation
|
||||||
|
std::thread::sleep(Duration::from_millis(self.delay_ms));
|
||||||
|
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
let len = source.len();
|
||||||
|
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
|
||||||
|
let field = Field::new("item", inner.data_type().clone(), false);
|
||||||
|
let arr = FixedSizeListArray::new(
|
||||||
|
Arc::new(field),
|
||||||
|
self.dim as _,
|
||||||
|
inner,
|
||||||
|
Some(NullBuffer::new_valid(len)),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Arc::new(arr))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_test_batch() -> Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let text = StringArray::from(vec!["hello", "world"]);
    RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
        message: format!("Failed to create test batch: {}", e),
    })
}

#[test]
fn test_single_embedding_fast_path() {
    // Single embedding should execute without spawning threads
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let embed = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
    let embedding_def = EmbeddingDefinition::new("text", "test", Some("embedding"));

    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![(embedding_def, embed.clone() as Arc<dyn EmbeddingFunction>)];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let result = with_embeddings.next().unwrap().unwrap();
    assert!(result.column_by_name("embedding").is_some());
    assert_eq!(embed.get_call_count(), 1);
}
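
// Each of the three mocks below sleeps 100 ms on the same two-row batch; run
// serially the batch would take roughly 300 ms, so a wall-clock upper bound
// (not asserted here) could also witness the parallelism.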
#[test]
fn test_multiple_embeddings_parallel() {
    // Multiple embeddings should execute in parallel
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 100));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 100));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 100));

    let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));

    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
        (def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
        (def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
    ];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let result = with_embeddings.next().unwrap().unwrap();

    // Verify all embedding columns are present
    assert!(result.column_by_name("emb1").is_some());
    assert!(result.column_by_name("emb2").is_some());
    assert!(result.column_by_name("emb3").is_some());

    // Verify all embeddings were computed
    assert_eq!(embed1.get_call_count(), 1);
    assert_eq!(embed2.get_call_count(), 1);
    assert_eq!(embed3.get_call_count(), 1);
}
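
// Column order matters for downstream schema expectations: the source column
// comes first, then one embedding column per definition, in definition order.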
#[test]
fn test_embedding_column_order_preserved() {
    // Verify that embedding columns are added in the same order as definitions
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 10));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 10));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 10));

    let def1 = EmbeddingDefinition::new("text", "embed1", Some("first"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("second"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("third"));

    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1 as Arc<dyn EmbeddingFunction>),
        (def2, embed2 as Arc<dyn EmbeddingFunction>),
        (def3, embed3 as Arc<dyn EmbeddingFunction>),
    ];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let result = with_embeddings.next().unwrap().unwrap();
    let result_schema = result.schema();

    // Original column is first
    assert_eq!(result_schema.field(0).name(), "text");
    // Embedding columns follow in order
    assert_eq!(result_schema.field(1).name(), "first");
    assert_eq!(result_schema.field(2).name(), "second");
    assert_eq!(result_schema.field(3).name(), "third");
}
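
// `FailingEmbed` always errors from `compute_source_embeddings`; the failure
// must surface through `WithEmbeddings::next()` rather than being swallowed
// by a parallel worker.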
#[test]
fn test_embedding_error_propagation() {
    // Test that errors from embedding computation are properly propagated
    #[derive(Debug)]
    struct FailingEmbed {
        name: String,
    }

    impl EmbeddingFunction for FailingEmbed {
        fn name(&self) -> &str {
            &self.name
        }

        fn source_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::Utf8))
        }

        fn dest_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::new_fixed_size_list(
                DataType::Float32,
                2,
                true,
            )))
        }

        fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            Err(Error::Runtime {
                message: "Intentional failure".to_string(),
            })
        }

        fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            unimplemented!()
        }
    }

    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let embed = Arc::new(FailingEmbed {
        name: "failing".to_string(),
    });
    let def = EmbeddingDefinition::new("text", "failing", Some("emb"));

    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![(def, embed as Arc<dyn EmbeddingFunction>)];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let result = with_embeddings.next().unwrap();
    assert!(result.is_err());
    let err_msg = format!("{}", result.err().unwrap());
    assert!(err_msg.contains("Intentional failure"));
}

#[test]
fn test_maybe_embedded_with_no_embeddings() {
    // Test that MaybeEmbedded::No variant works correctly
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
    let table_def = lancedb::table::TableDefinition {
        schema: schema.clone(),
        column_definitions: vec![lancedb::table::ColumnDefinition {
            kind: lancedb::table::ColumnKind::Physical,
        }],
    };

    let mut maybe_embedded = MaybeEmbedded::try_new(reader, table_def, None).unwrap();

    let result = maybe_embedded.next().unwrap().unwrap();
    assert_eq!(result.num_columns(), 1);
    assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref());
}
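
// A condensed usage sketch distilled from the tests above (reusing this file's
// own `SlowMockEmbed` and `create_test_batch`; zero delay, so it runs fast):
#[test]
fn sketch_with_embeddings_usage() {
    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed = Arc::new(SlowMockEmbed::new("sketch".to_string(), 2, 0));
    let def = EmbeddingDefinition::new("text", "sketch", Some("embedding"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let mut stream =
        WithEmbeddings::new(reader, vec![(def, embed as Arc<dyn EmbeddingFunction>)]);
    // Each yielded batch carries the original "text" column plus a
    // FixedSizeList<Float32, 2> column named "embedding".
    let result = stream.next().unwrap().unwrap();
    assert_eq!(result.num_columns(), 2);
}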