mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-26 02:20:40 +00:00
Compare commits
40 Commits
python-v0.
...
codex/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
054d09abe9 | ||
|
|
ad51e2dd1f | ||
|
|
e9e904783c | ||
|
|
8500b16eca | ||
|
|
57e7282342 | ||
|
|
cc5f8070d7 | ||
|
|
dc0fb01f6b | ||
|
|
94b7781551 | ||
|
|
7bf020b3d5 | ||
|
|
12a98479dc | ||
|
|
e4552e577a | ||
|
|
f979a902ad | ||
|
|
5a7a8da567 | ||
|
|
0db8176445 | ||
|
|
bd84bba14d | ||
|
|
ac07f8068c | ||
|
|
bba362d372 | ||
|
|
042bc22468 | ||
|
|
68569906c6 | ||
|
|
c71c1fc822 | ||
|
|
4a6a0c856e | ||
|
|
f124c9d8d2 | ||
|
|
4e65748abf | ||
|
|
e897f3edab | ||
|
|
790ba7115b | ||
|
|
446a69b51b | ||
|
|
cd5f91bb7d | ||
|
|
4da01a0e65 | ||
|
|
1840aa7edc | ||
|
|
489c91c5d6 | ||
|
|
f0c3fe5c6d | ||
|
|
2f6d525802 | ||
|
|
4494eb9e56 | ||
|
|
d67a8743ba | ||
|
|
46fcbbc1e3 | ||
|
|
ff53b76ac0 | ||
|
|
2adb10e6a8 | ||
|
|
ac164c352b | ||
|
|
8bcac7e372 | ||
|
|
e496184ab2 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.23.1-beta.1"
|
||||
current_version = "0.24.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -75,6 +75,13 @@ jobs:
|
||||
VERSION="${VERSION#v}"
|
||||
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
|
||||
# Use "chore" for beta/rc versions, "feat" for stable releases
|
||||
if [[ "${VERSION}" == *beta* ]] || [[ "${VERSION}" == *rc* ]]; then
|
||||
COMMIT_TYPE="chore"
|
||||
else
|
||||
COMMIT_TYPE="feat"
|
||||
fi
|
||||
|
||||
cat <<EOF >/tmp/codex-prompt.txt
|
||||
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
|
||||
|
||||
@@ -84,10 +91,10 @@ jobs:
|
||||
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
|
||||
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
|
||||
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
|
||||
6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
|
||||
6. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
|
||||
7. Push the branch to origin. If the branch already exists, force-push your changes.
|
||||
8. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
|
||||
9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
|
||||
9. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
|
||||
10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
|
||||
Constraints:
|
||||
|
||||
10
.github/workflows/rust.yml
vendored
10
.github/workflows/rust.yml
vendored
@@ -48,6 +48,8 @@ jobs:
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Run clippy
|
||||
run: cargo clippy --profile ci --workspace --tests --all-features -- -D warnings
|
||||
- name: Run clippy (without remote feature)
|
||||
run: cargo clippy --profile ci --workspace --tests -- -D warnings
|
||||
|
||||
build-no-lock:
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -167,13 +169,13 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo build --profile ci --features remote --tests --locked --target ${{ matrix.target }}
|
||||
cargo build --profile ci --features aws,remote --tests --locked --target ${{ matrix.target }}
|
||||
- name: Run tests
|
||||
# Can only run tests when target matches host
|
||||
if: ${{ matrix.target == 'x86_64-pc-windows-msvc' }}
|
||||
run: |
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo test --profile ci --features remote --locked
|
||||
cargo test --profile ci --features aws,remote --locked
|
||||
|
||||
msrv:
|
||||
# Check the minimum supported Rust version
|
||||
@@ -181,7 +183,7 @@ jobs:
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
|
||||
msrv: ["1.88.0"] # This should match up with rust-version in Cargo.toml
|
||||
env:
|
||||
# Need up-to-date compilers for kernels
|
||||
CC: clang-18
|
||||
@@ -212,4 +214,6 @@ jobs:
|
||||
cargo update -p aws-sdk-sts --precise 1.51.0
|
||||
cargo update -p home --precise 0.5.9
|
||||
- name: cargo +${{ matrix.msrv }} check
|
||||
env:
|
||||
RUSTUP_TOOLCHAIN: ${{ matrix.msrv }}
|
||||
run: cargo check --profile ci --workspace --tests --benches --all-features
|
||||
|
||||
2516
Cargo.lock
generated
2516
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
58
Cargo.toml
58
Cargo.toml
@@ -12,40 +12,40 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
rust-version = "1.78.0"
|
||||
rust-version = "1.88.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=1.0.1", default-features = false }
|
||||
lance-core = "=1.0.1"
|
||||
lance-datagen = "=1.0.1"
|
||||
lance-file = "=1.0.1"
|
||||
lance-io = { "version" = "=1.0.1", default-features = false }
|
||||
lance-index = "=1.0.1"
|
||||
lance-linalg = "=1.0.1"
|
||||
lance-namespace = "=1.0.1"
|
||||
lance-namespace-impls = { "version" = "=1.0.1", default-features = false }
|
||||
lance-table = "=1.0.1"
|
||||
lance-testing = "=1.0.1"
|
||||
lance-datafusion = "=1.0.1"
|
||||
lance-encoding = "=1.0.1"
|
||||
lance-arrow = "=1.0.1"
|
||||
lance = { "version" = "=2.0.0-rc.4", default-features = false, "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=2.0.0-rc.4", default-features = false, "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=2.0.0-rc.4", default-features = false, "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=2.0.0-rc.4", "tag" = "v2.0.0-rc.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "56.2", optional = false }
|
||||
arrow-array = "56.2"
|
||||
arrow-data = "56.2"
|
||||
arrow-ipc = "56.2"
|
||||
arrow-ord = "56.2"
|
||||
arrow-schema = "56.2"
|
||||
arrow-select = "56.2"
|
||||
arrow-cast = "56.2"
|
||||
arrow = { version = "57.2", optional = false }
|
||||
arrow-array = "57.2"
|
||||
arrow-data = "57.2"
|
||||
arrow-ipc = "57.2"
|
||||
arrow-ord = "57.2"
|
||||
arrow-schema = "57.2"
|
||||
arrow-select = "57.2"
|
||||
arrow-cast = "57.2"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "50.1", default-features = false }
|
||||
datafusion-catalog = "50.1"
|
||||
datafusion-common = { version = "50.1", default-features = false }
|
||||
datafusion-execution = "50.1"
|
||||
datafusion-expr = "50.1"
|
||||
datafusion-physical-plan = "50.1"
|
||||
datafusion = { version = "51.0", default-features = false }
|
||||
datafusion-catalog = "51.0"
|
||||
datafusion-common = { version = "51.0", default-features = false }
|
||||
datafusion-execution = "51.0"
|
||||
datafusion-expr = "51.0"
|
||||
datafusion-physical-plan = "51.0"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "2.6.0", default-features = false, features = [
|
||||
"num-traits",
|
||||
|
||||
@@ -16,7 +16,7 @@ check_command_exists() {
|
||||
}
|
||||
|
||||
if [[ ! -e ./lancedb ]]; then
|
||||
if [[ -v SOPHON_READ_TOKEN ]]; then
|
||||
if [[ x${SOPHON_READ_TOKEN} != "x" ]]; then
|
||||
INPUT="lancedb-linux-x64"
|
||||
gh release \
|
||||
--repo lancedb/lancedb \
|
||||
|
||||
@@ -11,7 +11,7 @@ watch:
|
||||
theme:
|
||||
name: "material"
|
||||
logo: assets/logo.png
|
||||
favicon: assets/logo.png
|
||||
favicon: assets/favicon.ico
|
||||
palette:
|
||||
# Palette toggle for light mode
|
||||
- scheme: lancedb
|
||||
@@ -32,8 +32,6 @@ theme:
|
||||
- content.tooltips
|
||||
- toc.follow
|
||||
- navigation.top
|
||||
- navigation.tabs
|
||||
- navigation.tabs.sticky
|
||||
- navigation.footer
|
||||
- navigation.tracking
|
||||
- navigation.instant
|
||||
@@ -115,12 +113,13 @@ markdown_extensions:
|
||||
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
||||
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
||||
- markdown.extensions.toc:
|
||||
baselevel: 1
|
||||
permalink: ""
|
||||
toc_depth: 3
|
||||
permalink: true
|
||||
permalink_title: Anchor link to this section
|
||||
|
||||
nav:
|
||||
- API reference:
|
||||
- Overview: index.md
|
||||
- Documentation:
|
||||
- SDK Reference: index.md
|
||||
- Python: python/python.md
|
||||
- Javascript/TypeScript: js/globals.md
|
||||
- Java: java/java.md
|
||||
|
||||
BIN
docs/src/assets/favicon.ico
Normal file
BIN
docs/src/assets/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
@@ -0,0 +1,111 @@
|
||||
# VoyageAI Embeddings : Multimodal
|
||||
|
||||
VoyageAI embeddings can also be used to embed both text and image data, only some of the models support image data and you can check the list
|
||||
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings)
|
||||
|
||||
Supported multimodal models:
|
||||
|
||||
- `voyage-multimodal-3` - 1024 dimensions (text + images)
|
||||
- `voyage-multimodal-3.5` - Flexible dimensions (256, 512, 1024 default, 2048). Supports text, images, and video.
|
||||
|
||||
### Video Support (voyage-multimodal-3.5)
|
||||
|
||||
The `voyage-multimodal-3.5` model supports video input through:
|
||||
- Video URLs (`.mp4`, `.webm`, `.mov`, `.avi`, `.mkv`, `.m4v`, `.gif`)
|
||||
- Video file paths
|
||||
|
||||
Constraints: Max 20MB video size.
|
||||
|
||||
Supported parameters (to be passed in `create` method) are:
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|-------------------------|-------------------------------------------|
|
||||
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
|
||||
| `output_dimension` | `int` | `None` | Output dimension for voyage-multimodal-3.5. Valid: 256, 512, 1024, 2048 |
|
||||
|
||||
Usage Example:
|
||||
|
||||
```python
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry
|
||||
import pandas as pd
|
||||
|
||||
os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'
|
||||
|
||||
db = lancedb.connect(".lancedb")
|
||||
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")
|
||||
|
||||
|
||||
def image_to_base64(image_bytes: bytes):
|
||||
buffered = BytesIO(image_bytes)
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField() # image uri as the source
|
||||
image_bytes: str = func.SourceField() # image bytes base64 encoded as the source
|
||||
vector: Vector(func.ndims()) = func.VectorField() # vector column
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
|
||||
|
||||
|
||||
if "images" in db.table_names():
|
||||
db.drop_table("images")
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
|
||||
)
|
||||
```
|
||||
Now we can search using text from both the default vector column and the custom vector column
|
||||
```python
|
||||
|
||||
# text search
|
||||
actual = table.search("man's best friend", "vec_from_bytes").limit(1).to_pydantic(Images)[0]
|
||||
print(actual.label) # prints "dog"
|
||||
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
print(frombytes.label)
|
||||
|
||||
```
|
||||
|
||||
Because we're using a multi-modal embedding function, we can also search using images
|
||||
|
||||
```python
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(BytesIO(image_bytes))
|
||||
actual = table.search(query_image, "vec_from_bytes").limit(1).to_pydantic(Images)[0]
|
||||
print(actual.label == "dog")
|
||||
|
||||
# image search using a custom vector column
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
print(actual.label)
|
||||
|
||||
```
|
||||
@@ -1,8 +1,12 @@
|
||||
# API Reference
|
||||
# SDK Reference
|
||||
|
||||
This page contains the API reference for the SDKs supported by the LanceDB team.
|
||||
This site contains the API reference for the client SDKs supported by [LanceDB](https://lancedb.com).
|
||||
|
||||
- [Python](python/python.md)
|
||||
- [JavaScript/TypeScript](js/globals.md)
|
||||
- [Java](java/java.md)
|
||||
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
|
||||
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
|
||||
|
||||
!!! info "LanceDB Documentation"
|
||||
|
||||
If you're looking for the full documentation of LanceDB, visit [docs.lancedb.com](https://docs.lancedb.com).
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.23.1-beta.1</version>
|
||||
<version>0.24.1</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -85,17 +85,26 @@
|
||||
|
||||
/* Header gradient (only header area) */
|
||||
.md-header {
|
||||
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08);
|
||||
}
|
||||
|
||||
/* Improve brand title contrast on the lavender side */
|
||||
.md-header__title,
|
||||
.md-header__topic,
|
||||
.md-header__title .md-ellipsis,
|
||||
.md-header__topic .md-ellipsis {
|
||||
color: #2b1b3a;
|
||||
text-shadow: 0 1px 0 rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
/* Same colors as header for tabs (that hold the text) */
|
||||
.md-tabs {
|
||||
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
}
|
||||
|
||||
/* Dark scheme variant */
|
||||
[data-md-color-scheme="slate"] .md-header,
|
||||
[data-md-color-scheme="slate"] .md-tabs {
|
||||
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.23.1-beta.1</version>
|
||||
<version>0.24.1-final.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.23.1-beta.1</version>
|
||||
<version>0.24.1-final.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.23.1-beta.1"
|
||||
version = "0.24.1"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -36,6 +36,6 @@ aws-lc-rs = "=1.13.0"
|
||||
napi-build = "2.1"
|
||||
|
||||
[features]
|
||||
default = ["remote", "lancedb/default"]
|
||||
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
|
||||
fp16kernels = ["lancedb/fp16kernels"]
|
||||
remote = ["lancedb/remote"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.23.1-beta.1",
|
||||
"version": "0.24.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.26.1"
|
||||
current_version = "0.27.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.26.1"
|
||||
version = "0.27.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
rust-version = "1.75.0"
|
||||
rust-version = "1.88.0"
|
||||
|
||||
[lib]
|
||||
name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "56.2", features = ["pyarrow"] }
|
||||
arrow = { version = "57.2", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
lance-core.workspace = true
|
||||
lance-namespace.workspace = true
|
||||
lance-io.workspace = true
|
||||
env_logger.workspace = true
|
||||
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.25", features = [
|
||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.26", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
] }
|
||||
@@ -32,12 +32,12 @@ snafu.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.25", features = [
|
||||
pyo3-build-config = { version = "0.26", features = [
|
||||
"extension-module",
|
||||
"abi3-py39",
|
||||
] }
|
||||
|
||||
[features]
|
||||
default = ["remote", "lancedb/default"]
|
||||
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
|
||||
fp16kernels = ["lancedb/fp16kernels"]
|
||||
remote = ["lancedb/remote"]
|
||||
|
||||
@@ -13,6 +13,7 @@ __version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from urllib.parse import urlparse
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .io import StorageOptionsProvider
|
||||
from .remote import ClientConfig
|
||||
@@ -28,6 +29,39 @@ from .namespace import (
|
||||
)
|
||||
|
||||
|
||||
def _check_s3_bucket_with_dots(
|
||||
uri: str, storage_options: Optional[Dict[str, str]]
|
||||
) -> None:
|
||||
"""
|
||||
Check if an S3 URI has a bucket name containing dots and warn if no region
|
||||
is specified. S3 buckets with dots cannot use virtual-hosted-style URLs,
|
||||
which breaks automatic region detection.
|
||||
|
||||
See: https://github.com/lancedb/lancedb/issues/1898
|
||||
"""
|
||||
if not isinstance(uri, str) or not uri.startswith("s3://"):
|
||||
return
|
||||
|
||||
parsed = urlparse(uri)
|
||||
bucket = parsed.netloc
|
||||
|
||||
if "." not in bucket:
|
||||
return
|
||||
|
||||
# Check if region is provided in storage_options
|
||||
region_keys = {"region", "aws_region"}
|
||||
has_region = storage_options and any(k in storage_options for k in region_keys)
|
||||
|
||||
if not has_region:
|
||||
raise ValueError(
|
||||
f"S3 bucket name '{bucket}' contains dots, which prevents automatic "
|
||||
f"region detection. Please specify the region explicitly via "
|
||||
f"storage_options={{'region': '<your-region>'}} or "
|
||||
f"storage_options={{'aws_region': '<your-region>'}}. "
|
||||
f"See https://github.com/lancedb/lancedb/issues/1898 for details."
|
||||
)
|
||||
|
||||
|
||||
def connect(
|
||||
uri: URI,
|
||||
*,
|
||||
@@ -121,9 +155,11 @@ def connect(
|
||||
storage_options=storage_options,
|
||||
**kwargs,
|
||||
)
|
||||
_check_s3_bucket_with_dots(str(uri), storage_options)
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(f"Unknown keyword arguments: {kwargs}")
|
||||
|
||||
return LanceDBConnection(
|
||||
uri,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
@@ -211,6 +247,8 @@ async def connect_async(
|
||||
if isinstance(client_config, dict):
|
||||
client_config = ClientConfig(**client_config)
|
||||
|
||||
_check_s3_bucket_with_dots(str(uri), storage_options)
|
||||
|
||||
return AsyncConnection(
|
||||
await lancedb_connect(
|
||||
sanitize_uri(uri),
|
||||
|
||||
@@ -179,6 +179,7 @@ class Table:
|
||||
cleanup_since_ms: Optional[int] = None,
|
||||
delete_unverified: Optional[bool] = None,
|
||||
) -> OptimizeStats: ...
|
||||
async def uri(self) -> str: ...
|
||||
@property
|
||||
def tags(self) -> Tags: ...
|
||||
def query(self) -> Query: ...
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import base64
|
||||
import os
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
@@ -45,11 +45,29 @@ def is_valid_url(text):
|
||||
return False
|
||||
|
||||
|
||||
VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}
|
||||
|
||||
|
||||
def is_video_url(url: str) -> bool:
|
||||
"""Check if URL points to a video file based on extension."""
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path.lower()
|
||||
return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)
|
||||
|
||||
|
||||
def is_video_path(path: Path) -> bool:
|
||||
"""Check if file path is a video file based on extension."""
|
||||
return path.suffix.lower() in VIDEO_EXTENSIONS
|
||||
|
||||
|
||||
def transform_input(input_data: Union[str, bytes, Path]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(input_data, str):
|
||||
if is_valid_url(input_data):
|
||||
content = {"type": "image_url", "image_url": input_data}
|
||||
if is_video_url(input_data):
|
||||
content = {"type": "video_url", "video_url": input_data}
|
||||
else:
|
||||
content = {"type": "image_url", "image_url": input_data}
|
||||
else:
|
||||
content = {"type": "text", "text": input_data}
|
||||
elif isinstance(input_data, PIL.Image.Image):
|
||||
@@ -70,14 +88,24 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
elif isinstance(input_data, Path):
|
||||
img = PIL.Image.open(input_data)
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
content = {
|
||||
"type": "image_base64",
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
if is_video_path(input_data):
|
||||
# Read video file and encode as base64
|
||||
with open(input_data, "rb") as f:
|
||||
video_bytes = f.read()
|
||||
video_str = base64.b64encode(video_bytes).decode("utf-8")
|
||||
content = {
|
||||
"type": "video_base64",
|
||||
"video_base64": video_str,
|
||||
}
|
||||
else:
|
||||
img = PIL.Image.open(input_data)
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
content = {
|
||||
"type": "image_base64",
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
else:
|
||||
raise ValueError("Each input should be either str, bytes, Path or Image.")
|
||||
|
||||
@@ -91,6 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
|
||||
inputs = [inputs]
|
||||
elif isinstance(inputs, list):
|
||||
pass # Already a list, use as-is
|
||||
elif isinstance(inputs, pa.Array):
|
||||
inputs = inputs.to_pylist()
|
||||
elif isinstance(inputs, pa.ChunkedArray):
|
||||
@@ -143,11 +173,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-multimodal-3
|
||||
* voyage-multimodal-3.5
|
||||
* voyage-finance-2
|
||||
* voyage-multilingual-2
|
||||
* voyage-law-2
|
||||
* voyage-code-2
|
||||
|
||||
output_dimension: int, optional
|
||||
The output dimension for models that support flexible dimensions.
|
||||
Currently only voyage-multimodal-3.5 supports this feature.
|
||||
Valid options: 256, 512, 1024 (default), 2048.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -175,7 +210,10 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str
|
||||
output_dimension: Optional[int] = None
|
||||
client: ClassVar = None
|
||||
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
|
||||
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
|
||||
text_embedding_models: list = [
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
@@ -186,7 +224,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"voyage-law-2",
|
||||
"voyage-code-2",
|
||||
]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3"]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"]
|
||||
contextual_embedding_models: list = ["voyage-context-3"]
|
||||
|
||||
def _is_multimodal_model(self, model_name: str):
|
||||
@@ -198,6 +236,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
return model_name in self.contextual_embedding_models or "context" in model_name
|
||||
|
||||
def ndims(self):
|
||||
# Handle flexible dimension models
|
||||
if self.name in self._FLEXIBLE_DIM_MODELS:
|
||||
if self.output_dimension is not None:
|
||||
if self.output_dimension not in self._VALID_DIMENSIONS:
|
||||
raise ValueError(
|
||||
f"Invalid output_dimension {self.output_dimension} "
|
||||
f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
|
||||
)
|
||||
return self.output_dimension
|
||||
return 1024 # default dimension
|
||||
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
@@ -211,12 +260,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
"voyage-multimodal-3",
|
||||
]:
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def _get_multimodal_kwargs(self, **kwargs):
|
||||
"""Get kwargs for multimodal embed call, including output_dimension if set."""
|
||||
if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
|
||||
kwargs["output_dimension"] = self.output_dimension
|
||||
return kwargs
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
@@ -234,6 +288,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
client = VoyageAIEmbeddingFunction._get_client()
|
||||
if self._is_multimodal_model(self.name):
|
||||
kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
@@ -275,6 +330,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
)
|
||||
if has_images:
|
||||
# Use non-batched API for images
|
||||
kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=sanitized, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
@@ -357,6 +413,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
callable: A function that takes a batch of texts and returns embeddings.
|
||||
"""
|
||||
if self._is_multimodal_model(self.name):
|
||||
multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
batch_inputs = sanitize_multimodal_input(batch)
|
||||
@@ -364,7 +421,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
inputs=batch_inputs,
|
||||
model=self.name,
|
||||
input_type=input_type,
|
||||
**kwargs,
|
||||
**multimodal_kwargs,
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
|
||||
@@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
return pa.timestamp("us", tz=tz)
|
||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||
child = py_type.__args__[0]
|
||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||
return _pydantic_list_child_to_arrow(child, field)
|
||||
raise TypeError(
|
||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||
)
|
||||
@@ -298,12 +298,18 @@ else:
|
||||
|
||||
|
||||
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
||||
def _safe_issubclass(candidate: Any, base: type) -> bool:
|
||||
try:
|
||||
return issubclass(candidate, base)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
if inspect.isclass(tp):
|
||||
if issubclass(tp, pydantic.BaseModel):
|
||||
if _safe_issubclass(tp, pydantic.BaseModel):
|
||||
# Struct
|
||||
fields = _pydantic_model_to_fields(tp)
|
||||
return pa.struct(fields)
|
||||
if issubclass(tp, FixedSizeListMixin):
|
||||
if _safe_issubclass(tp, FixedSizeListMixin):
|
||||
if getattr(tp, "is_multi_vector", lambda: False)():
|
||||
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
|
||||
# For regular Vector
|
||||
@@ -311,45 +317,67 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
||||
return _py_type_to_arrow_type(tp, field)
|
||||
|
||||
|
||||
def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType:
|
||||
unwrapped = _unwrap_optional_annotation(child)
|
||||
if unwrapped is not None:
|
||||
return pa.list_(
|
||||
pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True)
|
||||
)
|
||||
return pa.list_(_pydantic_type_to_arrow_type(child, field))
|
||||
|
||||
|
||||
def _unwrap_optional_annotation(annotation: Any) -> Any | None:
|
||||
if isinstance(annotation, (_GenericAlias, GenericAlias)):
|
||||
origin = annotation.__origin__
|
||||
args = annotation.__args__
|
||||
if origin == Union:
|
||||
non_none = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none) == 1 and len(non_none) != len(args):
|
||||
return non_none[0]
|
||||
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
|
||||
args = annotation.__args__
|
||||
non_none = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none) == 1 and len(non_none) != len(args):
|
||||
return non_none[0]
|
||||
return None
|
||||
|
||||
|
||||
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
"""Convert a Pydantic FieldInfo to Arrow DataType"""
|
||||
unwrapped = _unwrap_optional_annotation(field.annotation)
|
||||
if unwrapped is not None:
|
||||
return _pydantic_type_to_arrow_type(unwrapped, field)
|
||||
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
||||
origin = field.annotation.__origin__
|
||||
args = field.annotation.__args__
|
||||
|
||||
if origin is list:
|
||||
child = args[0]
|
||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||
elif origin == Union:
|
||||
if len(args) == 2 and args[1] is type(None):
|
||||
return _pydantic_type_to_arrow_type(args[0], field)
|
||||
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
|
||||
args = field.annotation.__args__
|
||||
if len(args) == 2:
|
||||
for typ in args:
|
||||
if typ is type(None):
|
||||
continue
|
||||
return _py_type_to_arrow_type(typ, field)
|
||||
return _pydantic_list_child_to_arrow(child, field)
|
||||
return _pydantic_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
def is_nullable(field: FieldInfo) -> bool:
|
||||
"""Check if a Pydantic FieldInfo is nullable."""
|
||||
if _unwrap_optional_annotation(field.annotation) is not None:
|
||||
return True
|
||||
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
||||
origin = field.annotation.__origin__
|
||||
args = field.annotation.__args__
|
||||
if origin == Union:
|
||||
if len(args) == 2 and args[1] is type(None):
|
||||
if any(typ is type(None) for typ in args):
|
||||
return True
|
||||
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
|
||||
args = field.annotation.__args__
|
||||
for typ in args:
|
||||
if typ is type(None):
|
||||
return True
|
||||
elif inspect.isclass(field.annotation) and issubclass(
|
||||
field.annotation, FixedSizeListMixin
|
||||
):
|
||||
return field.annotation.nullable()
|
||||
elif inspect.isclass(field.annotation):
|
||||
try:
|
||||
if issubclass(field.annotation, FixedSizeListMixin):
|
||||
return field.annotation.nullable()
|
||||
except TypeError:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -384,6 +384,7 @@ class RemoteDBConnection(DBConnection):
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
mode: Optional[str] = None,
|
||||
exist_ok: bool = False,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: Optional[List[str]] = None,
|
||||
@@ -412,6 +413,12 @@ class RemoteDBConnection(DBConnection):
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
mode: str, default "create"
|
||||
The mode to use when creating the table.
|
||||
Can be either "create", "overwrite", or "exist_ok".
|
||||
exist_ok: bool, default False
|
||||
If exist_ok is True, and mode is None or "create", mode will be changed
|
||||
to "exist_ok".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
@@ -483,6 +490,11 @@ class RemoteDBConnection(DBConnection):
|
||||
LanceTable(table4)
|
||||
|
||||
"""
|
||||
if exist_ok:
|
||||
if mode == "create":
|
||||
mode = "exist_ok"
|
||||
elif not mode:
|
||||
mode = "exist_ok"
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
validate_table_name(name)
|
||||
|
||||
@@ -18,7 +18,17 @@ from lancedb._lancedb import (
|
||||
UpdateResult,
|
||||
)
|
||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
|
||||
from lancedb.index import (
|
||||
FTS,
|
||||
BTree,
|
||||
Bitmap,
|
||||
HnswSq,
|
||||
IvfFlat,
|
||||
IvfPq,
|
||||
IvfRq,
|
||||
IvfSq,
|
||||
LabelList,
|
||||
)
|
||||
from lancedb.remote.db import LOOP
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -265,6 +275,12 @@ class RemoteTable(Table):
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_SQ":
|
||||
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
@@ -279,7 +295,8 @@ class RemoteTable(Table):
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
|
||||
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
)
|
||||
|
||||
LOOP.run(
|
||||
@@ -638,6 +655,14 @@ class RemoteTable(Table):
|
||||
def stats(self):
|
||||
return LOOP.run(self._table.stats())
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
"""The table URI (storage location).
|
||||
|
||||
For remote tables, this fetches the location from the server via describe.
|
||||
"""
|
||||
return LOOP.run(self._table.uri())
|
||||
|
||||
def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
|
||||
return LanceTakeQueryBuilder(self._table.take_offsets(offsets))
|
||||
|
||||
|
||||
@@ -2218,6 +2218,10 @@ class LanceTable(Table):
|
||||
def stats(self) -> TableStatistics:
|
||||
return LOOP.run(self._table.stats())
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return LOOP.run(self._table.uri())
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -3606,6 +3610,20 @@ class AsyncTable:
|
||||
"""
|
||||
return await self._inner.stats()
|
||||
|
||||
async def uri(self) -> str:
|
||||
"""
|
||||
Get the table URI (storage location).
|
||||
|
||||
For remote tables, this fetches the location from the server via describe.
|
||||
For local tables, this returns the dataset URI.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The full storage location of the table (e.g., S3/GCS path).
|
||||
"""
|
||||
return await self._inner.uri()
|
||||
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
|
||||
@@ -2,12 +2,27 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from lancedb.db import AsyncConnection, DBConnection
|
||||
import lancedb
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
def pandas_string_type():
|
||||
"""Return the PyArrow string type that pandas uses for string columns.
|
||||
|
||||
pandas 3.0+ uses large_string for string columns, pandas 2.x uses string.
|
||||
"""
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
version = tuple(int(x) for x in pd.__version__.split(".")[:2])
|
||||
if version >= (3, 0):
|
||||
return pa.large_utf8()
|
||||
return pa.utf8()
|
||||
|
||||
|
||||
# Use an in-memory database for most tests.
|
||||
@pytest.fixture
|
||||
def mem_db() -> DBConnection:
|
||||
|
||||
@@ -268,6 +268,8 @@ async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConne
|
||||
|
||||
|
||||
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -286,10 +288,11 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
assert tbl.schema == tbl2.schema
|
||||
assert len(tbl) == len(tbl2)
|
||||
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
@@ -299,7 +302,7 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
bad_schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
pa.field("extra", pa.float32()),
|
||||
]
|
||||
@@ -365,6 +368,8 @@ async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -382,10 +387,11 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
assert tbl.name == tbl2.name
|
||||
assert await tbl.schema() == await tbl2.schema()
|
||||
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
@@ -595,6 +601,8 @@ def test_open_table_sync(tmp_db: lancedb.DBConnection):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_table(tmp_path):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
@@ -614,10 +622,11 @@ async def test_open_table(tmp_path):
|
||||
)
|
||||
is not None
|
||||
)
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
assert await tbl.schema() == pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), list_size=2),
|
||||
"item": pa.utf8(),
|
||||
"item": pandas_string_type(),
|
||||
"price": pa.float64(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -613,6 +613,133 @@ def test_voyageai_multimodal_embedding_text_function():
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_embedding_function():
|
||||
"""Test voyage-multimodal-3.5 model with text input."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
assert voyageai.ndims() == 1024
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_flexible_dimensions():
|
||||
"""Test voyage-multimodal-3.5 model with custom output dimension."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
assert voyageai.ndims() == 512
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == 512
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_image_embedding():
|
||||
"""Test voyage-multimodal-3.5 model with image input."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", max_retries=0)
|
||||
)
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table(
|
||||
"test_multimodal_35_images", schema=Images, mode="overwrite"
|
||||
)
|
||||
labels = ["cat", "dog"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
]
|
||||
table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
|
||||
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
assert voyageai.ndims() == 1024
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
|
||||
def test_voyageai_multimodal_35_all_dimensions(dimension):
|
||||
"""Test voyage-multimodal-3.5 model with all valid output dimensions."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
|
||||
)
|
||||
|
||||
assert voyageai.ndims() == dimension
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table(
|
||||
f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
|
||||
)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == dimension
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_invalid_dimension():
|
||||
"""Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
|
||||
with pytest.raises(ValueError, match="Invalid output_dimension"):
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
|
||||
)
|
||||
# ndims() is where the validation happens
|
||||
voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
|
||||
@@ -26,6 +26,8 @@ import pytest
|
||||
from lance_namespace import (
|
||||
CreateEmptyTableRequest,
|
||||
CreateEmptyTableResponse,
|
||||
DeclareTableRequest,
|
||||
DeclareTableResponse,
|
||||
DescribeTableRequest,
|
||||
DescribeTableResponse,
|
||||
LanceNamespace,
|
||||
@@ -160,6 +162,19 @@ class TrackingNamespace(LanceNamespace):
|
||||
|
||||
return modified
|
||||
|
||||
def declare_table(self, request: DeclareTableRequest) -> DeclareTableResponse:
|
||||
"""Track declare_table calls and inject rotating credentials."""
|
||||
with self.lock:
|
||||
self.create_call_count += 1
|
||||
count = self.create_call_count
|
||||
|
||||
response = self.inner.declare_table(request)
|
||||
response.storage_options = self._modify_storage_options(
|
||||
response.storage_options, count
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def create_empty_table(
|
||||
self, request: CreateEmptyTableRequest
|
||||
) -> CreateEmptyTableResponse:
|
||||
|
||||
@@ -105,6 +105,253 @@ def test_optional_types_py310():
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10),
|
||||
reason="using PEP 604 union types requires python3.10 or higher",
|
||||
)
|
||||
def test_optional_structs_py310():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
split: SplitInfo | None = None
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"split",
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
),
|
||||
True,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10),
|
||||
reason="using PEP 604 union types requires python3.10 or higher",
|
||||
)
|
||||
def test_optional_struct_list_py310():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: list[SplitInfo] | None = None
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
)
|
||||
),
|
||||
True,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9),
|
||||
reason="using native type alias requires python3.9 or higher",
|
||||
)
|
||||
def test_nested_struct_list():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: list[SplitInfo]
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
)
|
||||
),
|
||||
False,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9),
|
||||
reason="using native type alias requires python3.9 or higher",
|
||||
)
|
||||
def test_nested_struct_list_optional():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: Optional[list[SplitInfo]] = None
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
)
|
||||
),
|
||||
True,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
def test_nested_struct_list_optional_items():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: list[Optional[SplitInfo]]
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.field(
|
||||
"item",
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
),
|
||||
True,
|
||||
)
|
||||
),
|
||||
False,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
def test_nested_struct_list_optional_container_and_items():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: Optional[list[Optional[SplitInfo]]] = None
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.field(
|
||||
"item",
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
),
|
||||
True,
|
||||
)
|
||||
),
|
||||
True,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10),
|
||||
reason="using PEP 604 union types requires python3.10 or higher",
|
||||
)
|
||||
def test_nested_struct_list_optional_items_pep604():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
end_frame: int
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
id: str
|
||||
splits: list[SplitInfo | None]
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
expect_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.utf8(), False),
|
||||
pa.field(
|
||||
"splits",
|
||||
pa.list_(
|
||||
pa.field(
|
||||
"item",
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("start_frame", pa.int64(), False),
|
||||
pa.field("end_frame", pa.int64(), False),
|
||||
]
|
||||
),
|
||||
True,
|
||||
)
|
||||
),
|
||||
False,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info > (3, 8),
|
||||
reason="using native type alias requires python3.9 or higher",
|
||||
|
||||
@@ -168,6 +168,42 @@ def test_table_len_sync():
|
||||
assert len(table) == 1
|
||||
|
||||
|
||||
def test_create_table_exist_ok():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create/?mode=exist_ok":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}], exist_ok=True)
|
||||
assert table is not None
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True)
|
||||
assert table is not None
|
||||
|
||||
|
||||
def test_create_table_exist_ok_with_mode_overwrite():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create/?mode=overwrite":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
|
||||
assert table is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error():
|
||||
request_id_holder = {"request_id": None}
|
||||
|
||||
68
python/python/tests/test_s3_bucket_dots.py
Normal file
68
python/python/tests/test_s3_bucket_dots.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""
|
||||
Tests for S3 bucket names containing dots.
|
||||
|
||||
Related issue: https://github.com/lancedb/lancedb/issues/1898
|
||||
|
||||
These tests validate the early error checking for S3 bucket names with dots.
|
||||
No actual S3 connection is made - validation happens before connection.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import lancedb
|
||||
|
||||
# Test URIs
|
||||
BUCKET_WITH_DOTS = "s3://my.bucket.name/path"
|
||||
BUCKET_WITH_DOTS_AND_REGION = ("s3://my.bucket.name", {"region": "us-east-1"})
|
||||
BUCKET_WITH_DOTS_AND_AWS_REGION = ("s3://my.bucket.name", {"aws_region": "us-east-1"})
|
||||
BUCKET_WITHOUT_DOTS = "s3://my-bucket/path"
|
||||
|
||||
|
||||
class TestS3BucketWithDotsSync:
|
||||
"""Tests for connect()."""
|
||||
|
||||
def test_bucket_with_dots_requires_region(self):
|
||||
with pytest.raises(ValueError, match="contains dots"):
|
||||
lancedb.connect(BUCKET_WITH_DOTS)
|
||||
|
||||
def test_bucket_with_dots_and_region_passes(self):
|
||||
uri, opts = BUCKET_WITH_DOTS_AND_REGION
|
||||
db = lancedb.connect(uri, storage_options=opts)
|
||||
assert db is not None
|
||||
|
||||
def test_bucket_with_dots_and_aws_region_passes(self):
|
||||
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
|
||||
db = lancedb.connect(uri, storage_options=opts)
|
||||
assert db is not None
|
||||
|
||||
def test_bucket_without_dots_passes(self):
|
||||
db = lancedb.connect(BUCKET_WITHOUT_DOTS)
|
||||
assert db is not None
|
||||
|
||||
|
||||
class TestS3BucketWithDotsAsync:
|
||||
"""Tests for connect_async()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bucket_with_dots_requires_region(self):
|
||||
with pytest.raises(ValueError, match="contains dots"):
|
||||
await lancedb.connect_async(BUCKET_WITH_DOTS)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bucket_with_dots_and_region_passes(self):
|
||||
uri, opts = BUCKET_WITH_DOTS_AND_REGION
|
||||
db = await lancedb.connect_async(uri, storage_options=opts)
|
||||
assert db is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bucket_with_dots_and_aws_region_passes(self):
|
||||
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
|
||||
db = await lancedb.connect_async(uri, storage_options=opts)
|
||||
assert db is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bucket_without_dots_passes(self):
|
||||
db = await lancedb.connect_async(BUCKET_WITHOUT_DOTS)
|
||||
assert db is not None
|
||||
@@ -1967,3 +1967,9 @@ def test_add_table_with_empty_embeddings(tmp_path):
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
assert table.count_rows() == 1
|
||||
|
||||
|
||||
def test_table_uri(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = db.create_table("my_table", data=[{"x": 0}])
|
||||
assert table.uri == str(tmp_path / "my_table.lance")
|
||||
|
||||
@@ -528,12 +528,19 @@ def test_sanitize_data(
|
||||
else:
|
||||
expected_schema = schema
|
||||
else:
|
||||
from conftest import pandas_string_type
|
||||
|
||||
# polars uses large_string, pandas 3.0+ uses large_string, others use string
|
||||
if isinstance(data, pl.DataFrame):
|
||||
text_type = pa.large_utf8()
|
||||
elif isinstance(data, pd.DataFrame):
|
||||
text_type = pandas_string_type()
|
||||
else:
|
||||
text_type = pa.string()
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.large_utf8()
|
||||
if isinstance(data, pl.DataFrame)
|
||||
else pa.string(),
|
||||
"text": text_type,
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -10,8 +10,7 @@ use arrow::{
|
||||
use futures::stream::StreamExt;
|
||||
use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{
|
||||
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
|
||||
Python,
|
||||
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
@@ -36,8 +35,11 @@ impl RecordBatchStream {
|
||||
#[pymethods]
|
||||
impl RecordBatchStream {
|
||||
#[getter]
|
||||
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
|
||||
(*self.schema).clone().into_pyarrow(py)
|
||||
pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> {
|
||||
(*self.schema)
|
||||
.clone()
|
||||
.into_pyarrow(py)
|
||||
.map(|obj| obj.unbind())
|
||||
}
|
||||
|
||||
pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
|
||||
@@ -53,7 +55,12 @@ impl RecordBatchStream {
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
|
||||
Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
|
||||
Python::attach(|py| {
|
||||
inner_next
|
||||
.infer_error()?
|
||||
.to_pyarrow(py)
|
||||
.map(|obj| obj.unbind())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods,
|
||||
types::{PyDict, PyDictMethods},
|
||||
Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
@@ -114,7 +114,7 @@ impl Connection {
|
||||
data: Bound<'_, PyAny>,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
storage_options_provider: Option<PyObject>,
|
||||
storage_options_provider: Option<Py<PyAny>>,
|
||||
location: Option<String>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
@@ -152,7 +152,7 @@ impl Connection {
|
||||
schema: Bound<'_, PyAny>,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
storage_options_provider: Option<PyObject>,
|
||||
storage_options_provider: Option<Py<PyAny>>,
|
||||
location: Option<String>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
@@ -187,7 +187,7 @@ impl Connection {
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
storage_options_provider: Option<PyObject>,
|
||||
storage_options_provider: Option<Py<PyAny>>,
|
||||
index_cache_size: Option<u32>,
|
||||
location: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
@@ -304,9 +304,10 @@ impl Connection {
|
||||
},
|
||||
page_token,
|
||||
limit: limit.map(|l| l as i32),
|
||||
..Default::default()
|
||||
};
|
||||
let response = inner.list_namespaces(request).await.infer_error()?;
|
||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("namespaces", response.namespaces)?;
|
||||
dict.set_item("page_token", response.page_token)?;
|
||||
@@ -325,11 +326,12 @@ impl Connection {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
let py = self_.py();
|
||||
future_into_py(py, async move {
|
||||
use lance_namespace::models::{create_namespace_request, CreateNamespaceRequest};
|
||||
let mode_enum = mode.and_then(|m| match m.to_lowercase().as_str() {
|
||||
"create" => Some(create_namespace_request::Mode::Create),
|
||||
"exist_ok" => Some(create_namespace_request::Mode::ExistOk),
|
||||
"overwrite" => Some(create_namespace_request::Mode::Overwrite),
|
||||
use lance_namespace::models::CreateNamespaceRequest;
|
||||
// Mode is now a string field
|
||||
let mode_str = mode.and_then(|m| match m.to_lowercase().as_str() {
|
||||
"create" => Some("Create".to_string()),
|
||||
"exist_ok" => Some("ExistOk".to_string()),
|
||||
"overwrite" => Some("Overwrite".to_string()),
|
||||
_ => None,
|
||||
});
|
||||
let request = CreateNamespaceRequest {
|
||||
@@ -338,11 +340,12 @@ impl Connection {
|
||||
} else {
|
||||
Some(namespace)
|
||||
},
|
||||
mode: mode_enum,
|
||||
mode: mode_str,
|
||||
properties,
|
||||
..Default::default()
|
||||
};
|
||||
let response = inner.create_namespace(request).await.infer_error()?;
|
||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("properties", response.properties)?;
|
||||
Ok(dict.unbind())
|
||||
@@ -360,15 +363,16 @@ impl Connection {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
let py = self_.py();
|
||||
future_into_py(py, async move {
|
||||
use lance_namespace::models::{drop_namespace_request, DropNamespaceRequest};
|
||||
let mode_enum = mode.and_then(|m| match m.to_uppercase().as_str() {
|
||||
"SKIP" => Some(drop_namespace_request::Mode::Skip),
|
||||
"FAIL" => Some(drop_namespace_request::Mode::Fail),
|
||||
use lance_namespace::models::DropNamespaceRequest;
|
||||
// Mode and Behavior are now string fields
|
||||
let mode_str = mode.and_then(|m| match m.to_uppercase().as_str() {
|
||||
"SKIP" => Some("Skip".to_string()),
|
||||
"FAIL" => Some("Fail".to_string()),
|
||||
_ => None,
|
||||
});
|
||||
let behavior_enum = behavior.and_then(|b| match b.to_uppercase().as_str() {
|
||||
"RESTRICT" => Some(drop_namespace_request::Behavior::Restrict),
|
||||
"CASCADE" => Some(drop_namespace_request::Behavior::Cascade),
|
||||
let behavior_str = behavior.and_then(|b| match b.to_uppercase().as_str() {
|
||||
"RESTRICT" => Some("Restrict".to_string()),
|
||||
"CASCADE" => Some("Cascade".to_string()),
|
||||
_ => None,
|
||||
});
|
||||
let request = DropNamespaceRequest {
|
||||
@@ -377,11 +381,12 @@ impl Connection {
|
||||
} else {
|
||||
Some(namespace)
|
||||
},
|
||||
mode: mode_enum,
|
||||
behavior: behavior_enum,
|
||||
mode: mode_str,
|
||||
behavior: behavior_str,
|
||||
..Default::default()
|
||||
};
|
||||
let response = inner.drop_namespace(request).await.infer_error()?;
|
||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("properties", response.properties)?;
|
||||
dict.set_item("transaction_id", response.transaction_id)?;
|
||||
@@ -405,9 +410,10 @@ impl Connection {
|
||||
} else {
|
||||
Some(namespace)
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let response = inner.describe_namespace(request).await.infer_error()?;
|
||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("properties", response.properties)?;
|
||||
Ok(dict.unbind())
|
||||
@@ -434,9 +440,10 @@ impl Connection {
|
||||
},
|
||||
page_token,
|
||||
limit: limit.map(|l| l as i32),
|
||||
..Default::default()
|
||||
};
|
||||
let response = inner.list_tables(request).await.infer_error()?;
|
||||
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("tables", response.tables)?;
|
||||
dict.set_item("page_token", response.page_token)?;
|
||||
|
||||
@@ -40,7 +40,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
request_id,
|
||||
source,
|
||||
status_code,
|
||||
} => Python::with_gil(|py| {
|
||||
} => Python::attach(|py| {
|
||||
let message = err.to_string();
|
||||
let http_err_cls = py
|
||||
.import(intern!(py, "lancedb.remote.errors"))?
|
||||
@@ -75,7 +75,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
max_read_failures,
|
||||
source,
|
||||
status_code,
|
||||
} => Python::with_gil(|py| {
|
||||
} => Python::attach(|py| {
|
||||
let cause_err = http_from_rust_error(
|
||||
py,
|
||||
source.as_ref(),
|
||||
|
||||
@@ -12,7 +12,7 @@ pub struct PyHeaderProvider {
|
||||
|
||||
impl Clone for PyHeaderProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self {
|
||||
Python::attach(|py| Self {
|
||||
provider: self.provider.clone_ref(py),
|
||||
})
|
||||
}
|
||||
@@ -25,7 +25,7 @@ impl PyHeaderProvider {
|
||||
|
||||
/// Get headers from the Python provider (internal implementation)
|
||||
fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
// Call the get_headers method
|
||||
let result = self.provider.call_method0(py, "get_headers");
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ impl PyPermutationReader {
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let schema = reader.output_schema(selection).await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -453,7 +453,7 @@ impl Query {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -532,7 +532,7 @@ impl TakeQuery {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -627,7 +627,7 @@ impl FTSQuery {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -806,7 +806,7 @@ impl VectorQuery {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,20 +17,20 @@ use pyo3::types::PyDict;
|
||||
/// Internal wrapper around a Python object implementing StorageOptionsProvider
|
||||
pub struct PyStorageOptionsProvider {
|
||||
/// The Python object implementing fetch_storage_options()
|
||||
inner: PyObject,
|
||||
inner: Py<PyAny>,
|
||||
}
|
||||
|
||||
impl Clone for PyStorageOptionsProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self {
|
||||
Python::attach(|py| Self {
|
||||
inner: self.inner.clone_ref(py),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PyStorageOptionsProvider {
|
||||
pub fn new(obj: PyObject) -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
pub fn new(obj: Py<PyAny>) -> PyResult<Self> {
|
||||
Python::attach(|py| {
|
||||
// Verify the object has a fetch_storage_options method
|
||||
if !obj.bind(py).hasattr("fetch_storage_options")? {
|
||||
return Err(pyo3::exceptions::PyTypeError::new_err(
|
||||
@@ -60,7 +60,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
|
||||
let py_provider = self.py_provider.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
// Call the Python fetch_storage_options method
|
||||
let result = py_provider
|
||||
.inner
|
||||
@@ -119,7 +119,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> String {
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
// Call provider_id() method on the Python object
|
||||
let obj = self.py_provider.inner.bind(py);
|
||||
obj.call_method0("provider_id")
|
||||
@@ -143,7 +143,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
|
||||
/// This is the main entry point for converting Python StorageOptionsProvider objects
|
||||
/// to Rust trait objects that can be used by the Lance ecosystem.
|
||||
pub fn py_object_to_storage_options_provider(
|
||||
py_obj: PyObject,
|
||||
py_obj: Py<PyAny>,
|
||||
) -> PyResult<Arc<dyn StorageOptionsProvider>> {
|
||||
let py_provider = PyStorageOptionsProvider::new(py_obj)?;
|
||||
Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))
|
||||
|
||||
@@ -287,7 +287,7 @@ impl Table {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ impl Table {
|
||||
future_into_py(self_.py(), async move {
|
||||
let stats = inner.index_stats(&index_name).await.infer_error()?;
|
||||
if let Some(stats) = stats {
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
|
||||
dict.set_item("num_unindexed_rows", stats.num_unindexed_rows)?;
|
||||
@@ -467,7 +467,7 @@ impl Table {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let stats = inner.stats().await.infer_error()?;
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("total_bytes", stats.total_bytes)?;
|
||||
dict.set_item("num_rows", stats.num_rows)?;
|
||||
@@ -497,6 +497,11 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn uri(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move { inner.uri().await.infer_error() })
|
||||
}
|
||||
|
||||
pub fn __repr__(&self) -> String {
|
||||
match &self.inner {
|
||||
None => format!("ClosedTable({})", self.name),
|
||||
@@ -516,7 +521,7 @@ impl Table {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let versions = inner.list_versions().await.infer_error()?;
|
||||
let versions_as_dict = Python::with_gil(|py| {
|
||||
let versions_as_dict = Python::attach(|py| {
|
||||
versions
|
||||
.iter()
|
||||
.map(|v| {
|
||||
@@ -867,7 +872,7 @@ impl Tags {
|
||||
let tags = inner.tags().await.infer_error()?;
|
||||
let res = tags.list().await.infer_error()?;
|
||||
|
||||
Python::with_gil(|py| {
|
||||
Python::attach(|py| {
|
||||
let py_dict = PyDict::new(py);
|
||||
for (key, contents) in res {
|
||||
let value_dict = PyDict::new(py);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.23.1-beta.1"
|
||||
version = "0.24.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -104,11 +104,16 @@ test-log = "0.2"
|
||||
|
||||
|
||||
[features]
|
||||
default = ["aws", "gcs", "azure", "dynamodb", "oss"]
|
||||
default = []
|
||||
aws = ["lance/aws", "lance-io/aws", "lance-namespace-impls/dir-aws"]
|
||||
oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"]
|
||||
gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"]
|
||||
azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"]
|
||||
huggingface = [
|
||||
"lance/huggingface",
|
||||
"lance-io/huggingface",
|
||||
"lance-namespace-impls/dir-huggingface",
|
||||
]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
@@ -148,3 +153,6 @@ name = "ivf_pq"
|
||||
[[example]]
|
||||
name = "hybrid_search"
|
||||
required-features = ["sentence-transformers"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
||||
@@ -36,10 +36,42 @@ use crate::remote::{
|
||||
};
|
||||
use crate::table::{TableDefinition, WriteOptions};
|
||||
use crate::Table;
|
||||
use lance::io::ObjectStoreParams;
|
||||
pub use lance_encoding::version::LanceFileVersion;
|
||||
#[cfg(feature = "remote")]
|
||||
use lance_io::object_store::StorageOptions;
|
||||
use lance_io::object_store::StorageOptionsProvider;
|
||||
use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
|
||||
|
||||
fn merge_storage_options(
|
||||
store_params: &mut ObjectStoreParams,
|
||||
pairs: impl IntoIterator<Item = (String, String)>,
|
||||
) {
|
||||
let mut options = store_params.storage_options().cloned().unwrap_or_default();
|
||||
for (key, value) in pairs {
|
||||
options.insert(key, value);
|
||||
}
|
||||
let provider = store_params
|
||||
.storage_options_accessor
|
||||
.as_ref()
|
||||
.and_then(|accessor| accessor.provider().cloned());
|
||||
let accessor = if let Some(provider) = provider {
|
||||
StorageOptionsAccessor::with_initial_and_provider(options, provider)
|
||||
} else {
|
||||
StorageOptionsAccessor::with_static_options(options)
|
||||
};
|
||||
store_params.storage_options_accessor = Some(Arc::new(accessor));
|
||||
}
|
||||
|
||||
fn set_storage_options_provider(
|
||||
store_params: &mut ObjectStoreParams,
|
||||
provider: Arc<dyn StorageOptionsProvider>,
|
||||
) {
|
||||
let accessor = match store_params.storage_options().cloned() {
|
||||
Some(options) => StorageOptionsAccessor::with_initial_and_provider(options, provider),
|
||||
None => StorageOptionsAccessor::with_provider(provider),
|
||||
};
|
||||
store_params.storage_options_accessor = Some(Arc::new(accessor));
|
||||
}
|
||||
|
||||
/// A builder for configuring a [`Connection::table_names`] operation
|
||||
pub struct TableNamesBuilder {
|
||||
@@ -246,16 +278,14 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
///
|
||||
/// See available options at <https://lancedb.com/docs/storage/>
|
||||
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
let store_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
store_options.insert(key.into(), value.into());
|
||||
merge_storage_options(store_params, [(key.into(), value.into())]);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -269,19 +299,17 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
mut self,
|
||||
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
|
||||
) -> Self {
|
||||
let store_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
|
||||
for (key, value) in pairs {
|
||||
store_options.insert(key.into(), value.into());
|
||||
}
|
||||
let updates = pairs
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.into(), value.into()));
|
||||
merge_storage_options(store_params, updates);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -318,23 +346,21 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
/// This has no effect in LanceDB Cloud.
|
||||
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
|
||||
pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self {
|
||||
let storage_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
|
||||
storage_options.insert(
|
||||
OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(),
|
||||
if use_v2_manifest_paths {
|
||||
"true".to_string()
|
||||
} else {
|
||||
"false".to_string()
|
||||
},
|
||||
let value = if use_v2_manifest_paths {
|
||||
"true".to_string()
|
||||
} else {
|
||||
"false".to_string()
|
||||
};
|
||||
merge_storage_options(
|
||||
store_params,
|
||||
[(OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), value)],
|
||||
);
|
||||
self
|
||||
}
|
||||
@@ -344,19 +370,19 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
/// The default is `LanceFileVersion::Stable`.
|
||||
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
|
||||
pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self {
|
||||
let storage_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
|
||||
storage_options.insert(
|
||||
OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
|
||||
data_storage_version.to_string(),
|
||||
merge_storage_options(
|
||||
store_params,
|
||||
[(
|
||||
OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
|
||||
data_storage_version.to_string(),
|
||||
)],
|
||||
);
|
||||
self
|
||||
}
|
||||
@@ -381,13 +407,14 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
/// This allows tables to automatically refresh cloud storage credentials
|
||||
/// when they expire, enabling long-running operations on remote storage.
|
||||
pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
|
||||
self.request
|
||||
let store_params = self
|
||||
.request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options_provider = Some(provider);
|
||||
.get_or_insert(Default::default());
|
||||
set_storage_options_provider(store_params, provider);
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -450,15 +477,13 @@ impl OpenTableBuilder {
|
||||
///
|
||||
/// See available options at <https://lancedb.com/docs/storage/>
|
||||
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
let storage_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.lance_read_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_options
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
storage_options.insert(key.into(), value.into());
|
||||
merge_storage_options(store_params, [(key.into(), value.into())]);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -472,18 +497,16 @@ impl OpenTableBuilder {
|
||||
mut self,
|
||||
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
|
||||
) -> Self {
|
||||
let storage_options = self
|
||||
let store_params = self
|
||||
.request
|
||||
.lance_read_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_options
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
|
||||
for (key, value) in pairs {
|
||||
storage_options.insert(key.into(), value.into());
|
||||
}
|
||||
let updates = pairs
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.into(), value.into()));
|
||||
merge_storage_options(store_params, updates);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -507,12 +530,13 @@ impl OpenTableBuilder {
|
||||
/// This allows tables to automatically refresh cloud storage credentials
|
||||
/// when they expire, enabling long-running operations on remote storage.
|
||||
pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
|
||||
self.request
|
||||
let store_params = self
|
||||
.request
|
||||
.lance_read_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_options
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options_provider = Some(provider);
|
||||
.get_or_insert(Default::default());
|
||||
set_storage_options_provider(store_params, provider);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -868,6 +892,10 @@ pub struct ConnectBuilder {
|
||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
|
||||
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
|
||||
|
||||
impl ConnectBuilder {
|
||||
/// Create a new [`ConnectOptions`] with the given database URI.
|
||||
pub fn new(uri: &str) -> Self {
|
||||
@@ -1051,11 +1079,27 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
fn apply_env_defaults(
|
||||
env_var_to_remote_storage_option: &[(&str, &str)],
|
||||
options: &mut HashMap<String, String>,
|
||||
) {
|
||||
for (env_key, opt_key) in env_var_to_remote_storage_option {
|
||||
if let Ok(env_value) = std::env::var(env_key) {
|
||||
if !options.contains_key(*opt_key) {
|
||||
options.insert((*opt_key).to_string(), env_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
fn execute_remote(self) -> Result<Connection> {
|
||||
use crate::remote::db::RemoteDatabaseOptions;
|
||||
|
||||
let options = RemoteDatabaseOptions::parse_from_map(&self.request.options)?;
|
||||
let mut merged_options = self.request.options.clone();
|
||||
Self::apply_env_defaults(&ENV_VARS_TO_STORAGE_OPTS, &mut merged_options);
|
||||
let options = RemoteDatabaseOptions::parse_from_map(&merged_options)?;
|
||||
|
||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||
@@ -1277,8 +1321,6 @@ mod test_utils {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs::create_dir_all;
|
||||
|
||||
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
|
||||
use crate::query::QueryBase;
|
||||
use crate::query::{ExecutableQuery, QueryExecutionOptions};
|
||||
@@ -1302,6 +1344,23 @@ mod tests {
|
||||
assert_eq!(tc.connection.uri(), tc.uri);
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[test]
|
||||
fn test_apply_env_defaults() {
|
||||
let env_key = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_KEY";
|
||||
let env_val = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_VAL";
|
||||
let opts_key = "test_apply_env_defaults_environment_variable_opts_key";
|
||||
std::env::set_var(env_key, env_val);
|
||||
|
||||
let mut options = HashMap::new();
|
||||
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
|
||||
assert_eq!(Some(&env_val.to_string()), options.get(opts_key));
|
||||
|
||||
options.insert(opts_key.to_string(), "EXPLICIT-VALUE".to_string());
|
||||
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
|
||||
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
@@ -1325,25 +1384,27 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let tc = new_test_connection().await.unwrap();
|
||||
let db = tc.connection;
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
|
||||
let mut names = Vec::with_capacity(100);
|
||||
for _ in 0..100 {
|
||||
let mut name = uuid::Uuid::new_v4().to_string();
|
||||
let name = uuid::Uuid::new_v4().to_string();
|
||||
names.push(name.clone());
|
||||
name.push_str(".lance");
|
||||
create_dir_all(tmp_dir.path().join(&name)).unwrap();
|
||||
db.create_empty_table(name, schema.clone())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
names.sort();
|
||||
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
let db = connect(uri).execute().await.unwrap();
|
||||
let tables = db.table_names().execute().await.unwrap();
|
||||
let tables = db.table_names().limit(100).execute().await.unwrap();
|
||||
|
||||
assert_eq!(tables, names);
|
||||
|
||||
let tables = db
|
||||
.table_names()
|
||||
.start_after(&names[30])
|
||||
.limit(100)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1524,18 +1585,27 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn drop_table() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let tc = new_test_connection().await.unwrap();
|
||||
let db = tc.connection;
|
||||
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
let db = connect(uri).execute().await.unwrap();
|
||||
if tc.is_remote {
|
||||
// All the typical endpoints such as s3:///, file-object-store:///, etc. treat drop_table
|
||||
// as idempotent.
|
||||
assert!(db.drop_table("invalid_table", &[]).await.is_ok());
|
||||
} else {
|
||||
// The behavior of drop_table when using a file:/// endpoint differs from all other
|
||||
// object providers, in that it returns an error when deleting a non-existent table.
|
||||
assert!(matches!(
|
||||
db.drop_table("invalid_table", &[]).await,
|
||||
Err(crate::Error::TableNotFound { .. }),
|
||||
));
|
||||
}
|
||||
|
||||
// drop non-exist table
|
||||
assert!(matches!(
|
||||
db.drop_table("invalid_table", &[]).await,
|
||||
Err(crate::Error::TableNotFound { .. }),
|
||||
));
|
||||
|
||||
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
|
||||
db.create_empty_table("table1", schema.clone())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
db.drop_table("table1", &[]).await.unwrap();
|
||||
|
||||
let tables = db.table_names().execute().await.unwrap();
|
||||
|
||||
@@ -12,7 +12,7 @@ use lance::dataset::{builder::DatasetBuilder, ReadParams, WriteMode};
|
||||
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
|
||||
use lance_datafusion::utils::StreamingWriteSource;
|
||||
use lance_encoding::version::LanceFileVersion;
|
||||
use lance_io::object_store::StorageOptionsProvider;
|
||||
use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
|
||||
use lance_table::io::commit::commit_handler_from_url;
|
||||
use object_store::local::LocalFileSystem;
|
||||
use snafu::ResultExt;
|
||||
@@ -356,7 +356,13 @@ impl ListingDatabase {
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(lance::session::Session::default()));
|
||||
let os_params = ObjectStoreParams {
|
||||
storage_options: Some(options.storage_options.clone()),
|
||||
storage_options_accessor: if options.storage_options.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(StorageOptionsAccessor::with_static_options(
|
||||
options.storage_options.clone(),
|
||||
)))
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let (object_store, base_path) = ObjectStore::from_uri_and_params(
|
||||
@@ -463,9 +469,20 @@ impl ListingDatabase {
|
||||
validate_table_name(name)?;
|
||||
|
||||
let mut uri = self.uri.clone();
|
||||
// If the URI does not end with a slash, add one
|
||||
if !uri.ends_with('/') {
|
||||
uri.push('/');
|
||||
// If the URI does not end with a path separator, add one
|
||||
// Use forward slash for URIs (http://, s3://, gs://, file://, etc.)
|
||||
// Use platform-specific separator for local paths without scheme
|
||||
let has_scheme = uri.contains("://");
|
||||
let ends_with_separator = uri.ends_with('/') || uri.ends_with('\\');
|
||||
|
||||
if !ends_with_separator {
|
||||
if has_scheme {
|
||||
// URIs always use forward slash
|
||||
uri.push('/');
|
||||
} else {
|
||||
// Local path without scheme - use platform separator
|
||||
uri.push(std::path::MAIN_SEPARATOR);
|
||||
}
|
||||
}
|
||||
// Append the table name with the lance file extension
|
||||
uri.push_str(&format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
@@ -481,7 +498,13 @@ impl ListingDatabase {
|
||||
|
||||
async fn drop_tables(&self, names: Vec<String>) -> Result<()> {
|
||||
let object_store_params = ObjectStoreParams {
|
||||
storage_options: Some(self.storage_options.clone()),
|
||||
storage_options_accessor: if self.storage_options.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(StorageOptionsAccessor::with_static_options(
|
||||
self.storage_options.clone(),
|
||||
)))
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let mut uri = self.uri.clone();
|
||||
@@ -530,7 +553,7 @@ impl ListingDatabase {
|
||||
.lance_write_params
|
||||
.as_ref()
|
||||
.and_then(|p| p.store_params.as_ref())
|
||||
.and_then(|sp| sp.storage_options.as_ref());
|
||||
.and_then(|sp| sp.storage_options());
|
||||
|
||||
let storage_version_override = storage_options
|
||||
.and_then(|opts| opts.get(OPT_NEW_TABLE_STORAGE_VERSION))
|
||||
@@ -581,21 +604,20 @@ impl ListingDatabase {
|
||||
// will cause a new connection to be created, and that connection will
|
||||
// be dropped from the cache when python GCs the table object, which
|
||||
// confounds reuse across tables.
|
||||
if !self.storage_options.is_empty() {
|
||||
let storage_options = write_params
|
||||
if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
|
||||
let store_params = write_params
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
self.inherit_storage_options(storage_options);
|
||||
}
|
||||
|
||||
// Set storage options provider if available
|
||||
if self.storage_options_provider.is_some() {
|
||||
write_params
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options_provider = self.storage_options_provider.clone();
|
||||
let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
|
||||
if !self.storage_options.is_empty() {
|
||||
self.inherit_storage_options(&mut storage_options);
|
||||
}
|
||||
let accessor = if let Some(ref provider) = self.storage_options_provider {
|
||||
StorageOptionsAccessor::with_initial_and_provider(storage_options, provider.clone())
|
||||
} else {
|
||||
StorageOptionsAccessor::with_static_options(storage_options)
|
||||
};
|
||||
store_params.storage_options_accessor = Some(Arc::new(accessor));
|
||||
}
|
||||
|
||||
write_params.data_storage_version = self
|
||||
@@ -881,7 +903,13 @@ impl Database for ListingDatabase {
|
||||
validate_table_name(&request.target_table_name)?;
|
||||
|
||||
let storage_params = ObjectStoreParams {
|
||||
storage_options: Some(self.storage_options.clone()),
|
||||
storage_options_accessor: if self.storage_options.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(StorageOptionsAccessor::with_static_options(
|
||||
self.storage_options.clone(),
|
||||
)))
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let read_params = ReadParams {
|
||||
@@ -945,25 +973,28 @@ impl Database for ListingDatabase {
|
||||
// will cause a new connection to be created, and that connection will
|
||||
// be dropped from the cache when python GCs the table object, which
|
||||
// confounds reuse across tables.
|
||||
if !self.storage_options.is_empty() {
|
||||
let storage_options = request
|
||||
if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
|
||||
let store_params = request
|
||||
.lance_read_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_options
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
self.inherit_storage_options(storage_options);
|
||||
}
|
||||
|
||||
// Set storage options provider if available
|
||||
if self.storage_options_provider.is_some() {
|
||||
request
|
||||
.lance_read_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_options
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options_provider = self.storage_options_provider.clone();
|
||||
let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
|
||||
if !self.storage_options.is_empty() {
|
||||
self.inherit_storage_options(&mut storage_options);
|
||||
}
|
||||
// Preserve request-level provider if no connection-level provider exists
|
||||
let request_provider = store_params
|
||||
.storage_options_accessor
|
||||
.as_ref()
|
||||
.and_then(|a| a.provider().cloned());
|
||||
let provider = self.storage_options_provider.clone().or(request_provider);
|
||||
let accessor = if let Some(provider) = provider {
|
||||
StorageOptionsAccessor::with_initial_and_provider(storage_options, provider)
|
||||
} else {
|
||||
StorageOptionsAccessor::with_static_options(storage_options)
|
||||
};
|
||||
store_params.storage_options_accessor = Some(Arc::new(accessor));
|
||||
}
|
||||
|
||||
// Some ReadParams are exposed in the OpenTableBuilder, but we also
|
||||
@@ -1071,6 +1102,7 @@ mod tests {
|
||||
use crate::table::{Table, TableDefinition};
|
||||
use arrow_array::{Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::path::PathBuf;
|
||||
use tempfile::tempdir;
|
||||
|
||||
async fn setup_database() -> (tempfile::TempDir, ListingDatabase) {
|
||||
@@ -1869,7 +1901,9 @@ mod tests {
|
||||
let write_options = WriteOptions {
|
||||
lance_write_params: Some(lance::dataset::WriteParams {
|
||||
store_params: Some(lance::io::ObjectStoreParams {
|
||||
storage_options: Some(storage_options),
|
||||
storage_options_accessor: Some(Arc::new(
|
||||
StorageOptionsAccessor::with_static_options(storage_options),
|
||||
)),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
@@ -1943,7 +1977,9 @@ mod tests {
|
||||
let write_options = WriteOptions {
|
||||
lance_write_params: Some(lance::dataset::WriteParams {
|
||||
store_params: Some(lance::io::ObjectStoreParams {
|
||||
storage_options: Some(storage_options),
|
||||
storage_options_accessor: Some(Arc::new(
|
||||
StorageOptionsAccessor::with_static_options(storage_options),
|
||||
)),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
@@ -2046,6 +2082,19 @@ mod tests {
|
||||
assert_eq!(db_options.new_table_config.enable_stable_row_ids, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_uri() {
|
||||
let (_tempdir, db) = setup_database().await;
|
||||
|
||||
let mut pb = PathBuf::new();
|
||||
pb.push(db.uri.clone());
|
||||
pb.push("test.lance");
|
||||
|
||||
let expected = pb.to_str().unwrap();
|
||||
let uri = db.table_uri("test").ok().unwrap();
|
||||
assert_eq!(uri, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_client() {
|
||||
let (_tempdir, db) = setup_database().await;
|
||||
|
||||
@@ -10,13 +10,14 @@ use async_trait::async_trait;
|
||||
use lance_namespace::{
|
||||
models::{
|
||||
CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
|
||||
DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
|
||||
DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
|
||||
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
|
||||
DeclareTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse,
|
||||
DescribeTableRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest,
|
||||
ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
|
||||
},
|
||||
LanceNamespace,
|
||||
};
|
||||
use lance_namespace_impls::ConnectBuilder;
|
||||
use log::warn;
|
||||
|
||||
use crate::database::ReadConsistency;
|
||||
use crate::error::{Error, Result};
|
||||
@@ -137,6 +138,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
id: Some(request.namespace),
|
||||
page_token: request.start_after,
|
||||
limit: request.limit.map(|l| l as i32),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = self.namespace.list_tables(ns_request).await?;
|
||||
@@ -153,7 +155,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
table_id.push(request.name.clone());
|
||||
let describe_request = DescribeTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
version: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let describe_result = self.namespace.describe_table(describe_request).await;
|
||||
@@ -171,6 +173,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
// Drop the existing table - must succeed
|
||||
let drop_request = DropTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
self.namespace
|
||||
.drop_table(drop_request)
|
||||
@@ -202,29 +205,53 @@ impl Database for LanceNamespaceDatabase {
|
||||
let mut table_id = request.namespace.clone();
|
||||
table_id.push(request.name.clone());
|
||||
|
||||
let create_empty_request = CreateEmptyTableRequest {
|
||||
// Try declare_table first, falling back to create_empty_table for backwards
|
||||
// compatibility with older namespace clients that don't support declare_table
|
||||
let declare_request = DeclareTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
location: None,
|
||||
properties: if self.storage_options.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.storage_options.clone())
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let create_empty_response = self
|
||||
.namespace
|
||||
.create_empty_table(create_empty_request)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create empty table: {}", e),
|
||||
})?;
|
||||
let location = match self.namespace.declare_table(declare_request).await {
|
||||
Ok(response) => response.location.ok_or_else(|| Error::Runtime {
|
||||
message: "Table location is missing from declare_table response".to_string(),
|
||||
})?,
|
||||
Err(e) => {
|
||||
// Check if the error is "not supported" and try create_empty_table as fallback
|
||||
let err_str = e.to_string().to_lowercase();
|
||||
if err_str.contains("not supported") || err_str.contains("not implemented") {
|
||||
warn!(
|
||||
"declare_table is not supported by the namespace client, \
|
||||
falling back to deprecated create_empty_table. \
|
||||
create_empty_table is deprecated and will be removed in Lance 3.0.0. \
|
||||
Please upgrade your namespace client to support declare_table."
|
||||
);
|
||||
#[allow(deprecated)]
|
||||
let create_empty_request = CreateEmptyTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let location = create_empty_response
|
||||
.location
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Table location is missing from create_empty_table response".to_string(),
|
||||
})?;
|
||||
#[allow(deprecated)]
|
||||
let create_response = self
|
||||
.namespace
|
||||
.create_empty_table(create_empty_request)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create empty table: {}", e),
|
||||
})?;
|
||||
|
||||
create_response.location.ok_or_else(|| Error::Runtime {
|
||||
message: "Table location is missing from create_empty_table response"
|
||||
.to_string(),
|
||||
})?
|
||||
} else {
|
||||
return Err(Error::Runtime {
|
||||
message: format!("Failed to declare table: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let native_table = NativeTable::create_from_namespace(
|
||||
self.namespace.clone(),
|
||||
@@ -281,7 +308,10 @@ impl Database for LanceNamespaceDatabase {
|
||||
let mut table_id = namespace.to_vec();
|
||||
table_id.push(name.to_string());
|
||||
|
||||
let drop_request = DropTableRequest { id: Some(table_id) };
|
||||
let drop_request = DropTableRequest {
|
||||
id: Some(table_id),
|
||||
..Default::default()
|
||||
};
|
||||
self.namespace
|
||||
.drop_table(drop_request)
|
||||
.await
|
||||
@@ -436,8 +466,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -497,8 +526,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -561,8 +589,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -645,8 +672,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -701,8 +727,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -782,8 +807,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -816,8 +840,7 @@ mod tests {
|
||||
// Create a child namespace first
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(vec!["test_ns".into()]),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||
use datafusion_catalog::streaming::StreamingTable;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||
use datafusion_expr::col;
|
||||
use datafusion_physical_plan::{
|
||||
stream::RecordBatchStreamAdapter, streaming::PartitionStream,
|
||||
SendableRecordBatchStream as DataFusionRecordBatchStream,
|
||||
};
|
||||
use futures::TryStreamExt;
|
||||
use lance_core::ROW_ID;
|
||||
use lance_datafusion::exec::SessionContextExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
|
||||
@@ -27,6 +35,51 @@ pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||
|
||||
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
||||
|
||||
pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
|
||||
|
||||
struct OneShotPartitionStream {
|
||||
schema: arrow_schema::SchemaRef,
|
||||
stream: Mutex<Option<DataFusionRecordBatchStream>>,
|
||||
}
|
||||
|
||||
impl OneShotPartitionStream {
|
||||
fn new(schema: arrow_schema::SchemaRef, stream: DataFusionRecordBatchStream) -> Self {
|
||||
Self {
|
||||
schema,
|
||||
stream: Mutex::new(Some(stream)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartitionStream for OneShotPartitionStream {
|
||||
fn schema(&self) -> &arrow_schema::SchemaRef {
|
||||
&self.schema
|
||||
}
|
||||
|
||||
fn execute(&self, _ctx: Arc<datafusion_execution::TaskContext>) -> DataFusionRecordBatchStream {
|
||||
self.stream
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|mut stream| stream.take())
|
||||
.unwrap_or_else(|| {
|
||||
Box::pin(RecordBatchStreamAdapter::new(
|
||||
Arc::clone(&self.schema),
|
||||
futures::stream::empty::<
|
||||
std::result::Result<arrow_array::RecordBatch, DataFusionError>,
|
||||
>(),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OneShotPartitionStream {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OneShotPartitionStream")
|
||||
.field("schema", &self.schema)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Where to store the permutation table
|
||||
#[derive(Debug, Clone, Default)]
|
||||
enum PermutationDestination {
|
||||
@@ -167,10 +220,20 @@ impl PermutationBuilder {
|
||||
&self,
|
||||
data: SendableRecordBatchStream,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let memory_limit = std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT")
|
||||
.unwrap_or_else(|_| DEFAULT_MEMORY_LIMIT.to_string())
|
||||
.parse::<usize>()
|
||||
.unwrap_or_else(|_| {
|
||||
log::error!(
|
||||
"Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
|
||||
DEFAULT_MEMORY_LIMIT
|
||||
);
|
||||
DEFAULT_MEMORY_LIMIT
|
||||
});
|
||||
let ctx = SessionContext::new_with_config_rt(
|
||||
SessionConfig::default(),
|
||||
RuntimeEnvBuilder::new()
|
||||
.with_memory_limit(100 * 1024 * 1024, 1.0)
|
||||
.with_memory_limit(memory_limit, 1.0)
|
||||
.with_disk_manager_builder(
|
||||
DiskManagerBuilder::default()
|
||||
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
||||
@@ -178,12 +241,17 @@ impl PermutationBuilder {
|
||||
.build_arc()
|
||||
.unwrap(),
|
||||
);
|
||||
let df = ctx
|
||||
.read_one_shot(data.into_df_stream())
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to setup sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df_stream = data.into_df_stream();
|
||||
let schema = df_stream.schema();
|
||||
let partition = Arc::new(OneShotPartitionStream::new(schema.clone(), df_stream));
|
||||
let table = StreamingTable::try_new(schema, vec![partition]).map_err(|e| Error::Other {
|
||||
message: format!("Failed to create streaming table: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df = ctx.read_table(Arc::new(table)).map_err(|e| Error::Other {
|
||||
message: format!("Failed to setup sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df_stream = df
|
||||
.sort_by(vec![col(SPLIT_ID_COLUMN)])
|
||||
.map_err(|e| Error::Other {
|
||||
|
||||
@@ -171,7 +171,7 @@ impl Shuffler {
|
||||
// This is kind of an annoying limitation but if we allow runt clumps from batches then
|
||||
// clumps will get unaligned and we will mess up the clumps when we do the in-memory
|
||||
// shuffle step. If this is a problem we can probably figure out a better way to do this.
|
||||
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
|
||||
if !is_last && !(batch.num_rows() as u64).is_multiple_of(clump_size) {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Expected batch size ({}) to be divisible by clump size ({})",
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{
|
||||
iter,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
|
||||
@@ -158,7 +155,7 @@ impl Splitter {
|
||||
remaining_in_split
|
||||
};
|
||||
|
||||
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
|
||||
split_ids.extend(std::iter::repeat_n(split_id as u64, rows_to_add as usize));
|
||||
if done {
|
||||
// Quit early if we've run out of splits
|
||||
break;
|
||||
@@ -662,7 +659,7 @@ mod tests {
|
||||
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
|
||||
let mut expected = Vec::with_capacity(total_split_sizes as usize);
|
||||
for (i, size) in expected_split_sizes.iter().enumerate() {
|
||||
expected.extend(iter::repeat(i as u64).take(*size as usize));
|
||||
expected.extend(std::iter::repeat_n(i as u64, *size as usize));
|
||||
}
|
||||
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
|
||||
|
||||
|
||||
@@ -120,8 +120,13 @@ impl MemoryRegistry {
|
||||
}
|
||||
|
||||
/// A record batch reader that has embeddings applied to it
|
||||
/// This is a wrapper around another record batch reader that applies an embedding function
|
||||
/// when reading from the record batch
|
||||
///
|
||||
/// This is a wrapper around another record batch reader that applies embedding functions
|
||||
/// when reading from the record batch.
|
||||
///
|
||||
/// When multiple embedding functions are defined, they are computed in parallel using
|
||||
/// scoped threads to improve performance. For a single embedding function, computation
|
||||
/// is done inline without threading overhead.
|
||||
pub struct WithEmbeddings<R: RecordBatchReader> {
|
||||
inner: R,
|
||||
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
|
||||
@@ -235,6 +240,48 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
|
||||
column_definitions,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
|
||||
if self.embeddings.len() == 1 {
|
||||
let (fld, func) = &self.embeddings[0];
|
||||
let src_column =
|
||||
batch
|
||||
.column_by_name(&fld.source_column)
|
||||
.ok_or_else(|| Error::InvalidInput {
|
||||
message: format!("Source column '{}' not found", fld.source_column),
|
||||
})?;
|
||||
return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
|
||||
}
|
||||
|
||||
// Parallel path: multiple embeddings
|
||||
std::thread::scope(|s| {
|
||||
let handles: Vec<_> = self
|
||||
.embeddings
|
||||
.iter()
|
||||
.map(|(fld, func)| {
|
||||
let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
|
||||
Error::InvalidInput {
|
||||
message: format!("Source column '{}' not found", fld.source_column),
|
||||
}
|
||||
})?;
|
||||
|
||||
let handle =
|
||||
s.spawn(move || func.compute_source_embeddings(src_column.clone()));
|
||||
|
||||
Ok(handle)
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
handles
|
||||
.into_iter()
|
||||
.map(|h| {
|
||||
h.join().map_err(|e| Error::Runtime {
|
||||
message: format!("Thread panicked during embedding computation: {:?}", e),
|
||||
})?
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
|
||||
@@ -262,19 +309,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let batch = self.inner.next()?;
|
||||
match batch {
|
||||
Ok(mut batch) => {
|
||||
// todo: parallelize this
|
||||
for (fld, func) in self.embeddings.iter() {
|
||||
let src_column = batch.column_by_name(&fld.source_column).unwrap();
|
||||
let embedding = match func.compute_source_embeddings(src_column.clone()) {
|
||||
Ok(embedding) => embedding,
|
||||
Err(e) => {
|
||||
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
|
||||
"Error computing embedding: {}",
|
||||
e
|
||||
))))
|
||||
}
|
||||
};
|
||||
Ok(batch) => {
|
||||
let embeddings = match self.compute_embeddings_parallel(&batch) {
|
||||
Ok(emb) => emb,
|
||||
Err(e) => {
|
||||
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
|
||||
"Error computing embedding: {}",
|
||||
e
|
||||
))))
|
||||
}
|
||||
};
|
||||
|
||||
let mut batch = batch;
|
||||
for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
|
||||
let dst_field_name = fld
|
||||
.dest_column
|
||||
.clone()
|
||||
@@ -286,7 +333,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
|
||||
embedding.nulls().is_some(),
|
||||
);
|
||||
|
||||
match batch.try_with_column(dst_field.clone(), embedding) {
|
||||
match batch.try_with_column(dst_field.clone(), embedding.clone()) {
|
||||
Ok(b) => batch = b,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
@@ -297,10 +297,10 @@ impl IvfPqIndexBuilder {
|
||||
}
|
||||
|
||||
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
if dim % 16 == 0 {
|
||||
if dim.is_multiple_of(16) {
|
||||
// Should be more aggressive than this default.
|
||||
dim / 16
|
||||
} else if dim % 8 == 0 {
|
||||
} else if dim.is_multiple_of(8) {
|
||||
dim / 8
|
||||
} else {
|
||||
log::warn!(
|
||||
|
||||
@@ -25,13 +25,14 @@
|
||||
//!
|
||||
//! ## Crate Features
|
||||
//!
|
||||
//! ### Experimental Features
|
||||
//!
|
||||
//! These features are not enabled by default. They are experimental or in-development features that
|
||||
//! are not yet ready to be released.
|
||||
//!
|
||||
//! - `remote` - Enable remote client to connect to LanceDB cloud. This is not yet fully implemented
|
||||
//! and should not be enabled.
|
||||
//! - `aws` - Enable AWS S3 object store support.
|
||||
//! - `dynamodb` - Enable DynamoDB manifest store support.
|
||||
//! - `azure` - Enable Azure Blob Storage object store support.
|
||||
//! - `gcs` - Enable Google Cloud Storage object store support.
|
||||
//! - `oss` - Enable Alibaba Cloud OSS object store support.
|
||||
//! - `remote` - Enable remote client to connect to LanceDB cloud.
|
||||
//! - `huggingface` - Enable HuggingFace Hub integration for loading datasets from the Hub.
|
||||
//! - `fp16kernels` - Enable FP16 kernels for faster vector search on CPU.
|
||||
//!
|
||||
//! ### Quick Start
|
||||
//!
|
||||
@@ -50,17 +51,15 @@
|
||||
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
|
||||
//! - `db://dbname` - Lance Cloud
|
||||
//!
|
||||
//! You can also use [`ConnectOptions`] to configure the connection to the database.
|
||||
//! You can also use [`ConnectBuilder`] to configure the connection to the database.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use object_store::aws::AwsCredential;
|
||||
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
//! let db = lancedb::connect("data/sample-lancedb")
|
||||
//! .aws_creds(AwsCredential {
|
||||
//! key_id: "some_key".to_string(),
|
||||
//! secret_key: "some_secret".to_string(),
|
||||
//! token: None,
|
||||
//! })
|
||||
//! .storage_options([
|
||||
//! ("aws_access_key_id", "some_key"),
|
||||
//! ("aws_secret_access_key", "some_secret"),
|
||||
//! ])
|
||||
//! .execute()
|
||||
//! .await
|
||||
//! .unwrap();
|
||||
|
||||
@@ -1718,8 +1718,7 @@ mod tests {
|
||||
let namespace = vec!["test_ns".to_string()];
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(namespace.clone()),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -1744,8 +1743,7 @@ mod tests {
|
||||
let list_response = conn
|
||||
.list_tables(ListTablesRequest {
|
||||
id: Some(namespace.clone()),
|
||||
page_token: None,
|
||||
limit: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to list tables");
|
||||
@@ -1756,8 +1754,7 @@ mod tests {
|
||||
let list_response = namespace_client
|
||||
.list_tables(ListTablesRequest {
|
||||
id: Some(namespace.clone()),
|
||||
page_token: None,
|
||||
limit: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1797,8 +1794,7 @@ mod tests {
|
||||
let namespace = vec!["multi_table_ns".to_string()];
|
||||
conn.create_namespace(CreateNamespaceRequest {
|
||||
id: Some(namespace.clone()),
|
||||
mode: None,
|
||||
properties: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("Failed to create namespace");
|
||||
@@ -1823,8 +1819,7 @@ mod tests {
|
||||
let list_response = conn
|
||||
.list_tables(ListTablesRequest {
|
||||
id: Some(namespace.clone()),
|
||||
page_token: None,
|
||||
limit: None,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -204,6 +204,7 @@ pub struct RemoteTable<S: HttpSend = Sender> {
|
||||
server_version: ServerVersion,
|
||||
|
||||
version: RwLock<Option<u64>>,
|
||||
location: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> RemoteTable<S> {
|
||||
@@ -221,6 +222,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
identifier,
|
||||
server_version,
|
||||
version: RwLock::new(None),
|
||||
location: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -639,6 +641,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
struct TableDescription {
|
||||
version: u64,
|
||||
schema: JsonSchema,
|
||||
location: Option<String>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
||||
@@ -667,6 +670,7 @@ mod test_utils {
|
||||
identifier: name,
|
||||
server_version: version.map(ServerVersion).unwrap_or_default(),
|
||||
version: RwLock::new(None),
|
||||
location: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1088,6 +1092,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
|
||||
}
|
||||
}
|
||||
Index::IvfRq(index) => {
|
||||
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
|
||||
body[METRIC_TYPE_KEY] =
|
||||
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
|
||||
if let Some(num_partitions) = index.num_partitions {
|
||||
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
|
||||
}
|
||||
if let Some(num_bits) = index.num_bits {
|
||||
body["num_bits"] = serde_json::Value::Number(num_bits.into());
|
||||
}
|
||||
}
|
||||
Index::BTree(_) => {
|
||||
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
|
||||
}
|
||||
@@ -1450,8 +1465,28 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
message: "table_definition is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
fn dataset_uri(&self) -> &str {
|
||||
"NOT_SUPPORTED"
|
||||
async fn uri(&self) -> Result<String> {
|
||||
// Check if we already have the location cached
|
||||
{
|
||||
let location = self.location.read().await;
|
||||
if let Some(ref loc) = *location {
|
||||
return Ok(loc.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from server via describe
|
||||
let description = self.describe().await?;
|
||||
let location = description.location.ok_or_else(|| Error::NotSupported {
|
||||
message: "Table URI not supported by the server".into(),
|
||||
})?;
|
||||
|
||||
// Cache the location for future use
|
||||
{
|
||||
let mut cached_location = self.location.write().await;
|
||||
*cached_location = Some(location.clone());
|
||||
}
|
||||
|
||||
Ok(location)
|
||||
}
|
||||
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>> {
|
||||
@@ -3321,4 +3356,69 @@ mod tests {
|
||||
let result = table.drop_columns(&["old_col1", "old_col2"]).await.unwrap();
|
||||
assert_eq!(result.version, 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_uri() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 1, "schema": {"fields": []}, "location": "s3://bucket/path/to/table"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let uri = table.uri().await.unwrap();
|
||||
assert_eq!(uri, "s3://bucket/path/to/table");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_uri_missing_location() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
|
||||
|
||||
// Server returns response without location field
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 1, "schema": {"fields": []}}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let result = table.uri().await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(&result, Err(Error::NotSupported { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_uri_caching() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
let call_count = Arc::new(AtomicUsize::new(0));
|
||||
let call_count_clone = call_count.clone();
|
||||
|
||||
let table = Table::new_with_handler("my_table", move |request| {
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
|
||||
call_count_clone.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"version": 1, "schema": {"fields": []}, "location": "gs://bucket/table"}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
// First call should fetch from server
|
||||
let uri1 = table.uri().await.unwrap();
|
||||
assert_eq!(uri1, "gs://bucket/table");
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1);
|
||||
|
||||
// Second call should use cached value
|
||||
let uri2 = table.uri().await.unwrap();
|
||||
assert_eq!(uri2, "gs://bucket/table");
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,10 +40,10 @@ use lance_index::vector::pq::PQBuildParams;
|
||||
use lance_index::vector::sq::builder::SQBuildParams;
|
||||
use lance_index::DatasetIndexExt;
|
||||
use lance_index::IndexType;
|
||||
use lance_io::object_store::LanceNamespaceStorageOptionsProvider;
|
||||
use lance_io::object_store::{LanceNamespaceStorageOptionsProvider, StorageOptionsAccessor};
|
||||
use lance_namespace::models::{
|
||||
QueryTableRequest as NsQueryTableRequest, QueryTableRequestFullTextQuery,
|
||||
QueryTableRequestVector, StringFtsQuery,
|
||||
QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
|
||||
QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
|
||||
};
|
||||
use lance_namespace::LanceNamespace;
|
||||
use lance_table::format::Manifest;
|
||||
@@ -608,8 +608,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
async fn list_versions(&self) -> Result<Vec<Version>>;
|
||||
/// Get the table definition.
|
||||
async fn table_definition(&self) -> Result<TableDefinition>;
|
||||
/// Get the table URI
|
||||
fn dataset_uri(&self) -> &str;
|
||||
/// Get the table URI (storage location)
|
||||
async fn uri(&self) -> Result<String>;
|
||||
/// Get the storage options used when opening this table, if any.
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>>;
|
||||
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
|
||||
@@ -1317,11 +1317,12 @@ impl Table {
|
||||
self.inner.list_indices().await
|
||||
}
|
||||
|
||||
/// Get the underlying dataset URI
|
||||
/// Get the table URI (storage location)
|
||||
///
|
||||
/// Warning: This is an internal API and the return value is subject to change.
|
||||
pub fn dataset_uri(&self) -> &str {
|
||||
self.inner.dataset_uri()
|
||||
/// Returns the full storage location of the table (e.g., S3/GCS path).
|
||||
/// For remote tables, this fetches the location from the server via describe.
|
||||
pub async fn uri(&self) -> Result<String> {
|
||||
self.inner.uri().await
|
||||
}
|
||||
|
||||
/// Get the storage options used when opening this table, if any.
|
||||
@@ -1424,7 +1425,10 @@ impl Table {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let unioned = Arc::new(UnionExec::new(projected_plans));
|
||||
let unioned = UnionExec::try_new(projected_plans).map_err(|e| Error::Other {
|
||||
message: format!("Failed to union query plans: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
// We require 1 partition in the final output
|
||||
let repartitioned = RepartitionExec::try_new(
|
||||
unioned,
|
||||
@@ -1665,18 +1669,14 @@ impl NativeTable {
|
||||
|
||||
// Use DatasetBuilder::from_namespace which automatically fetches location
|
||||
// and storage options from the namespace
|
||||
let builder = DatasetBuilder::from_namespace(
|
||||
namespace_client.clone(),
|
||||
table_id,
|
||||
false, // Don't ignore namespace storage options
|
||||
)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
lance::Error::Namespace { source, .. } => Error::Runtime {
|
||||
message: format!("Failed to get table info from namespace: {:?}", source),
|
||||
},
|
||||
source => Error::Lance { source },
|
||||
})?;
|
||||
let builder = DatasetBuilder::from_namespace(namespace_client.clone(), table_id)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
lance::Error::Namespace { source, .. } => Error::Runtime {
|
||||
message: format!("Failed to get table info from namespace: {:?}", source),
|
||||
},
|
||||
source => Error::Lance { source },
|
||||
})?;
|
||||
|
||||
let dataset = builder
|
||||
.with_read_params(params)
|
||||
@@ -1880,7 +1880,13 @@ impl NativeTable {
|
||||
let store_params = params
|
||||
.store_params
|
||||
.get_or_insert_with(ObjectStoreParams::default);
|
||||
store_params.storage_options_provider = Some(storage_options_provider);
|
||||
let accessor = match store_params.storage_options().cloned() {
|
||||
Some(options) => {
|
||||
StorageOptionsAccessor::with_initial_and_provider(options, storage_options_provider)
|
||||
}
|
||||
None => StorageOptionsAccessor::with_provider(storage_options_provider),
|
||||
};
|
||||
store_params.storage_options_accessor = Some(Arc::new(accessor));
|
||||
|
||||
// Patch the params if we have a write store wrapper
|
||||
let params = match write_store_wrapper.clone() {
|
||||
@@ -2056,7 +2062,7 @@ impl NativeTable {
|
||||
return provided;
|
||||
}
|
||||
let suggested = suggested_num_sub_vectors(dim);
|
||||
if num_bits.is_some_and(|num_bits| num_bits == 4) && suggested % 2 != 0 {
|
||||
if num_bits.is_some_and(|num_bits| num_bits == 4) && !suggested.is_multiple_of(2) {
|
||||
// num_sub_vectors must be even when 4 bits are used
|
||||
suggested + 1
|
||||
} else {
|
||||
@@ -2348,7 +2354,10 @@ impl NativeTable {
|
||||
// Convert select to columns list
|
||||
let columns = match &vq.base.select {
|
||||
Select::All => None,
|
||||
Select::Columns(cols) => Some(cols.clone()),
|
||||
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
|
||||
column_names: Some(cols.clone()),
|
||||
column_aliases: None,
|
||||
})),
|
||||
Select::Dynamic(_) => {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
@@ -2401,7 +2410,7 @@ impl NativeTable {
|
||||
with_row_id: Some(vq.base.with_row_id),
|
||||
bypass_vector_index: Some(!vq.use_index),
|
||||
full_text_query,
|
||||
version: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
AnyQuery::Query(q) => {
|
||||
@@ -2421,7 +2430,10 @@ impl NativeTable {
|
||||
|
||||
let columns = match &q.select {
|
||||
Select::All => None,
|
||||
Select::Columns(cols) => Some(cols.clone()),
|
||||
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
|
||||
column_names: Some(cols.clone()),
|
||||
column_aliases: None,
|
||||
})),
|
||||
Select::Dynamic(_) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Dynamic columns are not supported for server-side query"
|
||||
@@ -2460,18 +2472,11 @@ impl NativeTable {
|
||||
columns,
|
||||
prefilter: Some(q.prefilter),
|
||||
offset: q.offset.map(|o| o as i32),
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
nprobes: None,
|
||||
vector_column: None, // No vector column for plain queries
|
||||
with_row_id: Some(q.with_row_id),
|
||||
bypass_vector_index: Some(true), // No vector index for plain queries
|
||||
full_text_query,
|
||||
version: None,
|
||||
fast_search: None,
|
||||
lower_bound: None,
|
||||
upper_bound: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3224,8 +3229,8 @@ impl BaseTable for NativeTable {
|
||||
Ok(results.into_iter().flatten().collect())
|
||||
}
|
||||
|
||||
fn dataset_uri(&self) -> &str {
|
||||
self.uri.as_str()
|
||||
async fn uri(&self) -> Result<String> {
|
||||
Ok(self.uri.clone())
|
||||
}
|
||||
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>> {
|
||||
@@ -3233,7 +3238,7 @@ impl BaseTable for NativeTable {
|
||||
.get()
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|dataset| dataset.storage_options().cloned())
|
||||
.and_then(|dataset| dataset.initial_storage_options().cloned())
|
||||
}
|
||||
|
||||
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
|
||||
@@ -3398,7 +3403,6 @@ pub struct FragmentSummaryStats {
|
||||
#[cfg(test)]
|
||||
#[allow(deprecated)]
|
||||
mod tests {
|
||||
use std::iter;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -4015,7 +4019,7 @@ mod tests {
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
|
||||
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
|
||||
Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))),
|
||||
],
|
||||
)],
|
||||
schema,
|
||||
@@ -5146,7 +5150,13 @@ mod tests {
|
||||
assert_eq!(ns_request.k, 10);
|
||||
assert_eq!(ns_request.offset, Some(5));
|
||||
assert_eq!(ns_request.filter, Some("id > 0".to_string()));
|
||||
assert_eq!(ns_request.columns, Some(vec!["id".to_string()]));
|
||||
assert_eq!(
|
||||
ns_request
|
||||
.columns
|
||||
.as_ref()
|
||||
.and_then(|c| c.column_names.as_ref()),
|
||||
Some(&vec!["id".to_string()])
|
||||
);
|
||||
assert_eq!(ns_request.vector_column, Some("vector".to_string()));
|
||||
assert_eq!(ns_request.distance_type, Some("l2".to_string()));
|
||||
assert!(ns_request.vector.single_vector.is_some());
|
||||
@@ -5187,7 +5197,13 @@ mod tests {
|
||||
assert_eq!(ns_request.k, 20);
|
||||
assert_eq!(ns_request.offset, Some(5));
|
||||
assert_eq!(ns_request.filter, Some("id > 5".to_string()));
|
||||
assert_eq!(ns_request.columns, Some(vec!["id".to_string()]));
|
||||
assert_eq!(
|
||||
ns_request
|
||||
.columns
|
||||
.as_ref()
|
||||
.and_then(|c| c.column_names.as_ref()),
|
||||
Some(&vec!["id".to_string()])
|
||||
);
|
||||
assert_eq!(ns_request.with_row_id, Some(true));
|
||||
assert_eq!(ns_request.bypass_vector_index, Some(true));
|
||||
assert!(ns_request.vector_column.is_none()); // No vector column for plain queries
|
||||
|
||||
@@ -101,6 +101,7 @@ impl DatasetRef {
|
||||
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
|
||||
refs::Ref::Version(_, None) => true, // No specific version, always checkout
|
||||
refs::Ref::Tag(_) => true, // Always checkout for tags
|
||||
refs::Ref::VersionNumber(target_ver) => version != target_ver,
|
||||
};
|
||||
|
||||
if should_checkout {
|
||||
|
||||
@@ -5,16 +5,19 @@
|
||||
|
||||
use regex::Regex;
|
||||
use std::env;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::process::{Child, ChildStdout, Command, Stdio};
|
||||
use std::process::Stdio;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdout, Command};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{connect, Connection};
|
||||
use anyhow::{bail, Result};
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use tempfile::{tempdir, TempDir};
|
||||
|
||||
pub struct TestConnection {
|
||||
pub uri: String,
|
||||
pub connection: Connection,
|
||||
pub is_remote: bool,
|
||||
_temp_dir: Option<TempDir>,
|
||||
_process: Option<TestProcess>,
|
||||
}
|
||||
@@ -37,6 +40,56 @@ pub async fn new_test_connection() -> Result<TestConnection> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_stdout_reader(
|
||||
mut stdout: BufReader<ChildStdout>,
|
||||
port_sender: mpsc::Sender<anyhow::Result<String>>,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
let print_stdout = env::var("PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT").is_ok();
|
||||
tokio::spawn(async move {
|
||||
let mut line = String::new();
|
||||
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
|
||||
loop {
|
||||
line.clear();
|
||||
let result = stdout.read_line(&mut line).await;
|
||||
if let Err(err) = result {
|
||||
port_sender
|
||||
.send(Err(anyhow!(
|
||||
"error while reading from process output: {}",
|
||||
err
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
return;
|
||||
} else if result.unwrap() == 0 {
|
||||
port_sender
|
||||
.send(Err(anyhow!(
|
||||
" hit EOF before reading port from process output."
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
if re.is_match(&line) {
|
||||
let caps = re.captures(&line).unwrap();
|
||||
port_sender.send(Ok(caps[1].to_string())).await.unwrap();
|
||||
break;
|
||||
}
|
||||
}
|
||||
loop {
|
||||
line.clear();
|
||||
match stdout.read_line(&mut line).await {
|
||||
Err(_) => return,
|
||||
Ok(0) => return,
|
||||
Ok(_size) => {
|
||||
if print_stdout {
|
||||
print!("{}", line);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
|
||||
let temp_dir = tempdir()?;
|
||||
let data_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
@@ -57,38 +110,25 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
|
||||
child: child_result.unwrap(),
|
||||
};
|
||||
let stdout = BufReader::new(process.child.stdout.take().unwrap());
|
||||
let port = read_process_port(stdout)?;
|
||||
let (port_sender, mut port_receiver) = mpsc::channel(5);
|
||||
let _reader = spawn_stdout_reader(stdout, port_sender).await;
|
||||
let port = match port_receiver.recv().await {
|
||||
None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
|
||||
Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
|
||||
Some(Ok(port)) => port,
|
||||
};
|
||||
let uri = "db://test";
|
||||
let host_override = format!("http://localhost:{}", port);
|
||||
let connection = create_new_connection(uri, &host_override).await?;
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
is_remote: true,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: Some(process),
|
||||
})
|
||||
}
|
||||
|
||||
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
|
||||
let mut line = String::new();
|
||||
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
|
||||
loop {
|
||||
let result = stdout.read_line(&mut line);
|
||||
if let Err(err) = result {
|
||||
bail!(format!(
|
||||
"read_process_port: error while reading from process output: {}",
|
||||
err
|
||||
));
|
||||
} else if result.unwrap() == 0 {
|
||||
bail!("read_process_port: hit EOF before reading port from process output.");
|
||||
}
|
||||
if re.is_match(&line) {
|
||||
let caps = re.captures(&line).unwrap();
|
||||
return Ok(caps[1].to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
|
||||
connect(uri)
|
||||
@@ -114,6 +154,7 @@ async fn new_local_connection() -> Result<TestConnection> {
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
is_remote: false,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: None,
|
||||
})
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{HashMap, HashSet},
|
||||
iter::repeat,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@@ -268,9 +267,10 @@ fn create_some_records() -> Result<impl IntoArrow> {
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||
Arc::new(StringArray::from_iter(
|
||||
repeat(Some("hello world".to_string())).take(TOTAL),
|
||||
)),
|
||||
Arc::new(StringArray::from_iter(std::iter::repeat_n(
|
||||
Some("hello world".to_string()),
|
||||
TOTAL,
|
||||
))),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
|
||||
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use arrow::buffer::NullBuffer;
|
||||
use arrow_array::{
|
||||
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lancedb::{
|
||||
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SlowMockEmbed {
|
||||
name: String,
|
||||
dim: usize,
|
||||
delay_ms: u64,
|
||||
call_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl SlowMockEmbed {
|
||||
pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
|
||||
Self {
|
||||
name,
|
||||
dim,
|
||||
delay_ms,
|
||||
call_count: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_call_count(&self) -> usize {
|
||||
self.call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingFunction for SlowMockEmbed {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn source_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::Utf8))
|
||||
}
|
||||
|
||||
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||
DataType::Float32,
|
||||
self.dim as _,
|
||||
true,
|
||||
)))
|
||||
}
|
||||
|
||||
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
// Simulate slow embedding computation
|
||||
std::thread::sleep(Duration::from_millis(self.delay_ms));
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
let len = source.len();
|
||||
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
|
||||
let field = Field::new("item", inner.data_type().clone(), false);
|
||||
let arr = FixedSizeListArray::new(
|
||||
Arc::new(field),
|
||||
self.dim as _,
|
||||
inner,
|
||||
Some(NullBuffer::new_valid(len)),
|
||||
);
|
||||
|
||||
Ok(Arc::new(arr))
|
||||
}
|
||||
|
||||
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_batch() -> Result<RecordBatch> {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
|
||||
let text = StringArray::from(vec!["hello", "world"]);
|
||||
RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create test batch: {}", e),
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_embedding_fast_path() {
|
||||
// Single embedding should execute without spawning threads
|
||||
let batch = create_test_batch().unwrap();
|
||||
let schema = batch.schema();
|
||||
|
||||
let embed = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
|
||||
let embedding_def = EmbeddingDefinition::new("text", "test", Some("embedding"));
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let embeddings = vec![(embedding_def, embed.clone() as Arc<dyn EmbeddingFunction>)];
|
||||
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
|
||||
|
||||
let result = with_embeddings.next().unwrap().unwrap();
|
||||
assert!(result.column_by_name("embedding").is_some());
|
||||
assert_eq!(embed.get_call_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_embeddings_parallel() {
|
||||
// Multiple embeddings should execute in parallel
|
||||
let batch = create_test_batch().unwrap();
|
||||
let schema = batch.schema();
|
||||
|
||||
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 100));
|
||||
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 100));
|
||||
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 100));
|
||||
|
||||
let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
|
||||
let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
|
||||
let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let embeddings = vec![
|
||||
(def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
|
||||
(def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
|
||||
(def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
|
||||
];
|
||||
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
|
||||
|
||||
let result = with_embeddings.next().unwrap().unwrap();
|
||||
|
||||
// Verify all embedding columns are present
|
||||
assert!(result.column_by_name("emb1").is_some());
|
||||
assert!(result.column_by_name("emb2").is_some());
|
||||
assert!(result.column_by_name("emb3").is_some());
|
||||
|
||||
// Verify all embeddings were computed
|
||||
assert_eq!(embed1.get_call_count(), 1);
|
||||
assert_eq!(embed2.get_call_count(), 1);
|
||||
assert_eq!(embed3.get_call_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_column_order_preserved() {
|
||||
// Verify that embedding columns are added in the same order as definitions
|
||||
let batch = create_test_batch().unwrap();
|
||||
let schema = batch.schema();
|
||||
|
||||
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 10));
|
||||
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 10));
|
||||
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 10));
|
||||
|
||||
let def1 = EmbeddingDefinition::new("text", "embed1", Some("first"));
|
||||
let def2 = EmbeddingDefinition::new("text", "embed2", Some("second"));
|
||||
let def3 = EmbeddingDefinition::new("text", "embed3", Some("third"));
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let embeddings = vec![
|
||||
(def1, embed1 as Arc<dyn EmbeddingFunction>),
|
||||
(def2, embed2 as Arc<dyn EmbeddingFunction>),
|
||||
(def3, embed3 as Arc<dyn EmbeddingFunction>),
|
||||
];
|
||||
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
|
||||
|
||||
let result = with_embeddings.next().unwrap().unwrap();
|
||||
let result_schema = result.schema();
|
||||
|
||||
// Original column is first
|
||||
assert_eq!(result_schema.field(0).name(), "text");
|
||||
// Embedding columns follow in order
|
||||
assert_eq!(result_schema.field(1).name(), "first");
|
||||
assert_eq!(result_schema.field(2).name(), "second");
|
||||
assert_eq!(result_schema.field(3).name(), "third");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embedding_error_propagation() {
|
||||
// Test that errors from embedding computation are properly propagated
|
||||
#[derive(Debug)]
|
||||
struct FailingEmbed {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl EmbeddingFunction for FailingEmbed {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn source_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::Utf8))
|
||||
}
|
||||
|
||||
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||
DataType::Float32,
|
||||
2,
|
||||
true,
|
||||
)))
|
||||
}
|
||||
|
||||
fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
Err(Error::Runtime {
|
||||
message: "Intentional failure".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
let batch = create_test_batch().unwrap();
|
||||
let schema = batch.schema();
|
||||
|
||||
let embed = Arc::new(FailingEmbed {
|
||||
name: "failing".to_string(),
|
||||
});
|
||||
let def = EmbeddingDefinition::new("text", "failing", Some("emb"));
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let embeddings = vec![(def, embed as Arc<dyn EmbeddingFunction>)];
|
||||
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
|
||||
|
||||
let result = with_embeddings.next().unwrap();
|
||||
assert!(result.is_err());
|
||||
let err_msg = format!("{}", result.err().unwrap());
|
||||
assert!(err_msg.contains("Intentional failure"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_maybe_embedded_with_no_embeddings() {
|
||||
// Test that MaybeEmbedded::No variant works correctly
|
||||
let batch = create_test_batch().unwrap();
|
||||
let schema = batch.schema();
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
|
||||
let table_def = lancedb::table::TableDefinition {
|
||||
schema: schema.clone(),
|
||||
column_definitions: vec![lancedb::table::ColumnDefinition {
|
||||
kind: lancedb::table::ColumnKind::Physical,
|
||||
}],
|
||||
};
|
||||
|
||||
let mut maybe_embedded = MaybeEmbedded::try_new(reader, table_def, None).unwrap();
|
||||
|
||||
let result = maybe_embedded.next().unwrap().unwrap();
|
||||
assert_eq!(result.num_columns(), 1);
|
||||
assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref());
|
||||
}
|
||||
Reference in New Issue
Block a user