Compare commits

..

21 Commits

Author SHA1 Message Date
Will Jones
e244d753e5 docs(nodejs): regenerate Table.md for setLsmWriteSpec
JSDoc was updated in the previous commit; regenerate the markdown to match.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 15:00:53 -07:00
Will Jones
a69dd0f62b fix: remove primary key constraint from MemWAL bucket sharding
Lance v7.0.0-rc.1 intentionally removed the requirement for
bucket_sharding to match (or even require) the unenforced primary key
column. Update LanceDB to match: drop the PK-related doc comments and
the test assertions that expected rejection when no PK is set or when
the bucket column differs from the PK.

The Rust changes are taken from #3435; this commit additionally applies
the equivalent updates to the Python and TypeScript bindings.

See https://github.com/lance-format/lance/issues/6917

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 14:18:58 -07:00
Will Jones
02b112e931 test: ignore latest_version_hint.json in manifest dir assertions
Lance v7.0.0-rc.1 writes _versions/latest_version_hint.json on
non-lexically-ordered stores (e.g. the local filesystem) to speed up
latest-version lookup. The V2-manifest-path tests iterate the whole
_versions directory and asserted every entry matches the manifest
filename pattern; filter to *.manifest files so the hint file is ignored.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 12:54:31 -07:00
lancedb automation
05504f280a chore: update lance dependency to v7.0.0-rc.1 2026-05-21 22:55:23 +00:00
Lance Release
1bb7acb74f Bump version: 0.29.1-beta.0 → 0.30.0-beta.0 2026-05-21 21:36:18 +00:00
Lance Release
4ce175276c Bump version: 0.32.1-beta.0 → 0.33.0-beta.0 2026-05-21 21:35:22 +00:00
Justin Miller
4bccb43e56 fix(python): route sync BaseQueryBuilder.to_batches through async path (#3425)
## Summary

Fixes #3424.

`LanceTakeQueryBuilder.to_batches()` raised `AttributeError:
'AsyncTakeQuery' object has no attribute 'execute'`. The inherited
`BaseQueryBuilder.to_batches` called `self._inner.execute(...)`, but
`self._inner` is an `AsyncQueryBase` (Python wrapper) — only its native
inner exposes `execute`. Every other sync builder overrides
`to_batches`, so the bug only surfaced on take-query builders, which
inherit the base unchanged. `take_offsets(...).to_batches()` is broken
for the same reason.

Route the sync wrapper through the async `to_batches` on the background
event loop, so the native `execute` is invoked from inside an awaiting
context (matching how the async path works correctly).

## Repro

```python
import lancedb, pyarrow as pa, tempfile
db = lancedb.connect(tempfile.mkdtemp())
tbl = db.create_table("t", data=pa.table({"a": list(range(100))}))

tbl.take_row_ids([0, 1, 2]).to_arrow()        # works
tbl.search().to_batches()                     # works
list(tbl.take_row_ids([0, 1, 2]).to_batches())  # AttributeError (before)
```

## Test plan

- [x] New regression test `test_take_queries_to_batches` covers
`take_offsets(...).to_batches()`, `take_row_ids(...).to_batches()`, and
the `select(...)` projection — all fail on `main` with the patch
reverted, all pass with the fix.
- [x] `test_take_queries`, `test_query_builder_batches`, and
`test_query_schema` still pass.
- [x] `ruff format --check` and `ruff check` clean on changed files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 12:11:13 -07:00
Xuanwo
d5dc4c0f06 fix: discover nested vector columns by default (#3423)
LanceDB default vector column discovery only considered top-level
fields, so tables with a single nested vector leaf still required users
to pass an explicit field path. This updates Rust and Python discovery
to recurse into struct fields, return canonical field paths, and
preserve actionable errors when no default or multiple defaults exist.

The explicit nested path flow for index creation and search remains
supported across Rust, Python, and Node, with regression coverage for
single nested vector leaves, multiple candidate leaves, and schemas
without vector leaves.

Closes #3405.
2026-05-21 19:02:41 +08:00
Sean Mackrory
55ae6197c1 fix(python): drop version from Table __repr__ (#3411)
There have been a couple of reports of this function freezing debuggers
because it triggers a network round-trip but is assumed to be extremely
light-weight: https://github.com/lancedb/lancedb/discussions/2853. We'll
just cache the last version we see.

I considered digging into see if we could assume or get the version at
create time or after other operations, but that could be a bit of a
rabbit hole as I'm a bit unfamiliar with this. Claude was having a hard
time of it too 😅 I propose we see how the currently implementation goes
and improve it if people find "unknown" or stale values coming up
disruptively often before improving this further.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 12:20:46 -07:00
Pragnyan Ramtha
15bd821825 fix(python): check all table pages for db membership (#3395)
## Summary

- Fix `name in db` and `len(db)` for local Python connections with more
than one page of tables.
- Use `list_tables()` pagination instead of deprecated `table_names()`
with its default 10-item page.
- Add regression coverage with 20 tables so later pages are included.

Fixes #2727.

## Validation

- `python3 -m py_compile python/python/lancedb/db.py
python/python/tests/test_db.py`
- No-build Python harness that extracts and executes the edited
`LanceDBConnection` pagination methods: passed
- `uvx ruff check python/python/lancedb/db.py
python/python/tests/test_db.py`
- `uvx ruff format --check python/python/lancedb/db.py
python/python/tests/test_db.py`

Note: `uv run pytest
python/tests/test_db.py::test_db_contains_and_len_include_all_table_name_pages
-q` was attempted first, but it stayed in the broad Rust/PyO3 native
extension build and was stopped before pytest started.
2026-05-20 10:31:10 -07:00
Xuanwo
cf162c8a10 test(python): cover nested FTS field paths (#3418)
Adds regression coverage for Python FTS APIs targeting nested text
leaves, including sync and async match, phrase, and hybrid query paths.
This also locks in the intended error boundary: nested text leaf paths
are valid, while struct containers, non-text leaves, and missing paths
remain rejected.

Fixes #3404.
2026-05-21 00:49:00 +08:00
Xuanwo
2eba7ebd02 fix: return canonical nested index paths (#3413)
Index metadata APIs now resolve stored field ids back to Lance canonical
field paths instead of leaf names, so nested indexes such as
`metadata.user_id` and escaped literal-dot fields round-trip through
`list_indices()`. Native index creation also canonicalizes the input
path before handing it to Lance, keeping local metadata consistent with
the field-path contract while remote responses continue to expose
server-provided canonical columns.

Fixes #3403.
2026-05-21 00:20:47 +08:00
dependabot[bot]
2d5298b6ee chore(deps): bump the rust-minor-patch group across 1 directory with 23 updates (#3382)
Weekly dependabot refresh of `Cargo.lock`.

Dependabot's original PR also raised the lower-bound version
requirements
in `Cargo.toml` (arrow, tokio, aws-sdk-*, etc.) to match the new
lockfile
versions. That forces our library's consumers onto newer minimum
versions and broke the MSRV check, which downgrades aws-sdk-* crates to
verify they still build on Rust 1.91.

Changes from the original:

- Reverted all `Cargo.toml` requirement changes; `Cargo.lock`
regenerated
  with `cargo update` within the existing ranges. The lockfile (and the
  binaries we ship) stays current on security fixes without bumping our
  public minimum versions.
- Set `versioning-strategy: lockfile-only` in `.github/dependabot.yml`
so
  future cargo dependabot PRs only touch `Cargo.lock`.

Note: `aws-lc-rs` stays at 1.16.3 — `nodejs/Cargo.toml` pins it with
`=`,
which `lockfile-only` cannot move; bumping it needs a manual change.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Will Jones <will.jones127@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
2026-05-20 09:09:39 -07:00
Brendan Clement
4cb9147bbf feat(nodejs): add renameTable on Connection (#3386)
Adds `Connection.renameTable` to the Node SDK. Closes #3381.
2026-05-20 09:05:48 -07:00
Xuanwo
54a1982ef1 docs: document Python uv agent workflow (#3417) 2026-05-20 21:35:42 +08:00
Xuanwo
5bfde47a8e fix: support nested field paths in native index creation (#3408)
Native index creation was resolving requested columns through top-level
Arrow schema lookup before handing the request to Lance, which rejected
nested paths and could collapse a nested field to its leaf name. This PR
resolves index targets with Lance field-path semantics, passes the
canonical path through to Lance, and reports indexed columns from field
ids as canonical full paths.

This also removes the Python native FTS guard that rejected dotted paths
so scalar, vector, and FTS index creation share the same nested-field
contract. Related to #3402.
2026-05-20 11:15:15 +08:00
Brendan Clement
049b0c8f09 feat(nodejs): add progress to Table.add (#3398)
### Summary

- Add an optional `progress` callback to `Table.add(data, { progress
})`. Callback fires once per batch written and once more with `done:
true` when the write completes.
- Errors thrown from the user's callback are logged with `console.warn`
and swallowed

### Testing
- npm test 
- ran smoke test script to verify functionality
2026-05-19 18:35:07 -07:00
Vishal Kumar Singh
20556e23a9 docs: add missing Python index classes to API reference (#3392)
Adds three index configuration classes to the Python API Reference that
were missing from the documentation:

- `IvfSq` - IVF Scalar Quantization index
- `IvfRq` - IVF RabitQ Quantization index
- `HnswFlat` - HNSW without quantization (stores raw vectors)

These classes are exported in `lancedb.index.__all__` and have complete
docstrings in the source, but weren't showing up in the rendered docs at
https://lancedb.github.io/lancedb/python/python/#indices-asynchronous.

Closes #1855
2026-05-19 16:06:41 -07:00
Weston Pace
01e272c0b0 fix(rust): match embedding scannable columns by name (#3410)
Fixes #3136.

## Summary

- `WithEmbeddingsScannable::scan_as_stream` matched columns positionally
  against the table schema, so a `CastError` was raised whenever the
  computed batch order differed from the table schema order.
- The mismatch surfaced when `add_columns` added a new physical column
  **after** an embedding column: the table schema became
  `[..., embedding, extra]`, but `compute_embeddings_for_batch` always
  appends embeddings at the end, producing `[..., extra, embedding]`.
  Position 2 then tried to cast e.g. `score: Float64` →
  `embedding: FixedSizeList` and failed.
- Now we look each output column up by name in the result batch, which
  is order-independent. If a non-embedding column required by the table
  schema is missing from the input, we return a clear `InvalidInput`
  error instead of a confusing cast error.

## Reproduction (from the issue)

```text
Table created with:           [id, text, text_vec(embedding)]
add_columns("score")        → schema: [id, text, text_vec, score]
table.add([id, text, score]) → BEFORE: CastError on position 2
                               AFTER:  succeeds, embedding is computed
```

## Tests

-
`data::scannable.rs::test_with_embeddings_scannable_column_added_after_embedding`
  — unit test exercising the exact column-order mismatch via
  `WithEmbeddingsScannable::with_schema`.
-
`data::scannable.rs::test_with_embeddings_scannable_missing_required_column`
  — covers the new "missing column" error path.
- `table::add_data.rs::test_add_with_embeddings_after_add_columns`
  — end-to-end regression test mirroring the reproduction in the issue
  (create table with embedding → `add_columns` → `table.add`).

## Test plan

- [x] `cargo check --quiet --features remote --tests --examples`
- [x] `cargo clippy --quiet --features remote --tests --examples`
- [x] `cargo fmt --all`
- [x] `cargo test --quiet --features remote -p lancedb embedding` — 18
tests pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 15:08:12 -07:00
Xuanwo
ad1634a0a5 docs: document CI preflight requirements (#3409)
This updates the agent instructions to codify the CI gates that failed
before code review started: PR titles must satisfy Conventional Commits,
and Python changes need root-level ruff format/lint checks.

It also makes the touched-language preflight explicit so mixed Rust,
Python, and TypeScript changes run the checks CI expects before opening
a PR.
2026-05-19 08:51:29 -07:00
Yang Cen
5d1c28922a feat(python): align to_pandas pandas kwargs (#3397)
## Feature

This PR aligns LanceDB Python `to_pandas()` APIs with Lance pandas
conversion capabilities while keeping LanceDB query-specific semantics
intact.

- Adds `blob_mode` and pandas `**kwargs` support to local table
`to_pandas()`.
- Delegates local `LanceTable.to_pandas()` to Lance dataset
`to_pandas(blob_mode=..., **kwargs)`.
- Keeps remote table `to_pandas()` unsupported with
`NotImplementedError`.
- Allows sync and async query `to_pandas()` to forward pandas kwargs
after LanceDB `flatten` and `timeout` handling.

Why we need this feature:

Users can access Lance blob-aware pandas conversion from LanceDB local
tables and can pass PyArrow pandas conversion options through
table/query APIs without losing existing `flatten` or `timeout`
behavior.

How it works:

The table API exposes a `BlobMode` literal type for `lazy`, `bytes`, and
`descriptions`. Local tables call through to the backing Lance dataset.
Query APIs do not add `blob_mode`; they materialize Arrow results, apply
LanceDB flattening when requested, and then call `to_pandas(**kwargs)`.

## Validation

- `uv run --frozen --extra tests pytest
python/tests/test_table.py::test_table_to_pandas_default_matches_arrow
python/tests/test_table.py::test_table_to_pandas_blob_bytes
python/tests/test_table.py::test_table_to_pandas_kwargs
python/tests/test_query.py::test_query_to_pandas_kwargs
python/tests/test_query.py::test_query_timeout
python/tests/test_remote_db.py::test_table_to_pandas_not_supported`
- `uv run --frozen --extra dev ruff check python/lancedb/table.py
python/lancedb/query.py python/lancedb/remote/table.py
python/tests/test_table.py python/tests/test_query.py
python/tests/test_remote_db.py`
- `uv run --frozen --extra tests pytest python/tests/test_table.py
python/tests/test_query.py python/tests/test_remote_db.py`

Note: `python/uv.lock` was intentionally not committed in this branch.
2026-05-19 20:05:51 +08:00
73 changed files with 2831 additions and 2130 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.29.1-beta.0"
current_version = "0.30.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -11,6 +11,11 @@ updates:
schedule:
interval: weekly
open-pull-requests-limit: 10
# Only update Cargo.lock, never widen/raise the version requirements in
# Cargo.toml. The goal is keeping the lockfile (and the binaries we ship)
# current on security fixes, not forcing our library's consumers onto
# newer minimum versions.
versioning-strategy: lockfile-only
groups:
rust-minor-patch:
update-types:

View File

@@ -157,7 +157,10 @@ jobs:
npx jest --testEnvironment jest-environment-node-single-context --verbose
macos:
timeout-minutes: 30
runs-on: "macos-14"
# macos-15 ships a newer linker; the older macos-14 linker fails to insert
# branch islands when the debug cdylib's __text section exceeds the 128 MB
# AArch64 B/BL branch range.
runs-on: "macos-15"
defaults:
run:
shell: bash

View File

@@ -205,7 +205,7 @@ jobs:
- name: Delete wheels
run: rm -rf target/wheels
pydantic1x:
timeout-minutes: 30
timeout-minutes: 60
runs-on: "ubuntu-24.04"
defaults:
run:

View File

@@ -233,6 +233,26 @@ jobs:
cargo update -p aws-sdk-sso --precise 1.62.0
cargo update -p aws-sdk-ssooidc --precise 1.63.0
cargo update -p aws-sdk-sts --precise 1.63.0
# aws-runtime/sigv4/credential-types/types and the aws-smithy-*
# crates bumped their MSRV to 1.91.1 in late 2026; pin to the last
# 1.91.0-compatible versions. The order matters — each downgrade
# only succeeds once everything that still pins it at a higher
# version has itself been downgraded.
cargo update -p aws-runtime --precise 1.5.12
cargo update -p aws-types --precise 1.3.9
cargo update -p aws-sigv4 --precise 1.3.5
cargo update -p aws-credential-types --precise 1.2.8
cargo update -p aws-smithy-checksums --precise 0.63.9
cargo update -p aws-smithy-runtime --precise 1.9.3
cargo update -p aws-smithy-http --precise 0.62.4
cargo update -p aws-smithy-eventstream --precise 0.60.12
cargo update -p aws-smithy-http-client --precise 1.1.3
cargo update -p aws-smithy-observability --precise 0.1.4
cargo update -p aws-smithy-query --precise 0.60.8
cargo update -p aws-smithy-runtime-api --precise 1.9.1
cargo update -p aws-smithy-async --precise 1.2.6
cargo update -p aws-smithy-types --precise 1.3.5
cargo update -p aws-smithy-xml --precise 0.60.11
cargo update -p home --precise 0.5.9
- name: cargo +${{ matrix.msrv }} check
env:

View File

@@ -17,9 +17,30 @@ Common commands:
* Run tests: `cargo test --quiet --features remote --tests`
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
* Lint: `cargo clippy --quiet --features remote --tests --examples`
* Format: `cargo fmt --all`
* Format Rust: `cargo fmt --all`
* Format Python: `ruff format .`
* Lint Python: `ruff check .`
* Bootstrap Python dev env: `cd python && uv run --extra tests --extra dev maturin develop --extras tests,dev`
* Run Python tests: `cd python && uv run --extra tests pytest python/tests -vv --durations=10 -m "not slow and not s3_test"`
* Run specific Python test: `cd python && uv run --extra tests pytest python/tests/<test_file>.py::<test_name> -q`
Before committing changes, run formatting.
For Python validation, prefer the uv-managed environment declared by `python/uv.lock`.
Do not treat system `python`, global `pytest`, or missing editable-install errors as
final blockers; bootstrap or enter the uv environment instead. If `lancedb._lancedb`
is missing or stale, or if Rust/PyO3 binding code changed, rebuild the Python
extension with the bootstrap command above before running tests.
Before committing changes, run formatting for every language you touched. At minimum:
* Rust changes: run `cargo fmt --all`.
* Python changes: run `ruff format .` and `ruff check .` from the repository root,
and run targeted tests through `cd python && uv run ...`.
* TypeScript changes: run the relevant `npm`/`pnpm` lint, format, build, and docs commands in `nodejs`.
Before creating a PR, make sure the PR title follows Conventional Commits, such as
`fix: support nested field paths in native index creation` or
`feat(python): add dataset multiprocessing support`. The semantic-release check uses the
PR title and body as the merge commit message, so a non-conventional PR title will fail CI.
## Coding tips

1177
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=7.0.0-rc.1", default-features = false, "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=7.0.0-rc.1", default-features = false, "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=7.0.0-rc.1", default-features = false, "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=7.0.0-rc.1", "tag" = "v7.0.0-rc.1", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.29.1-beta.0</version>
<version>0.30.0-beta.0</version>
</dependency>
```

View File

@@ -441,18 +441,28 @@ Open a table in the database.
```ts
abstract renameTable(
oldName,
currentName,
newName,
namespacePath?): Promise<void>
options?): Promise<void>
```
Rename a table.
Currently only supported by LanceDB Cloud. Local OSS connections and
namespace-backed connections (via [connectNamespace](../functions/connectNamespace.md)) reject with
a "not supported" error.
#### Parameters
* **oldName**: `string`
* **currentName**: `string`
The current name of the table.
* **newName**: `string`
The new name for the table.
* **namespacePath?**: `string`[]
* **options?**: [`RenameTableOptions`](../interfaces/RenameTableOptions.md)
Optional namespace paths. When
`newNamespacePath` is omitted the table stays in `namespacePath`.
#### Returns

View File

@@ -701,15 +701,11 @@ LSM-style write path for future `mergeInsert` calls.
`LsmWriteSpec` chooses one of three sharding strategies via `specType`:
- `"bucket"` — hash-bucket writes by the single-column unenforced primary
key (`column` and `numBuckets` required).
- `"bucket"` — hash-bucket writes by a scalar `column` (`column` and
`numBuckets` required).
- `"identity"` — shard by the raw value of a scalar `column`.
- `"unsharded"` — route every write to a single shard.
All variants require the table to have an unenforced primary key
([Table#setUnenforcedPrimaryKey](Table.md#setunenforcedprimarykey)); bucket sharding additionally
requires it to be the single column being bucketed.
#### Parameters
* **spec**: [`LsmWriteSpec`](../interfaces/LsmWriteSpec.md)
@@ -722,7 +718,6 @@ requires it to be the single column being bucketed.
#### Example
```ts
await table.setUnenforcedPrimaryKey("id");
await table.setLsmWriteSpec({
specType: "bucket",
column: "id",

View File

@@ -12,7 +12,6 @@
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
- [OAuthFlowType](enumerations/OAuthFlowType.md)
- [Occur](enumerations/Occur.md)
- [Operator](enumerations/Operator.md)
@@ -83,13 +82,12 @@
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
- [MergeResult](interfaces/MergeResult.md)
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
- [OAuthConfig](interfaces/OAuthConfig.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RenameTableOptions](interfaces/RenameTableOptions.md)
- [RestNamespaceConfig](interfaces/RestNamespaceConfig.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ScannableOptions](interfaces/ScannableOptions.md)
@@ -107,6 +105,7 @@
- [UpdateResult](interfaces/UpdateResult.md)
- [Version](interfaces/Version.md)
- [WriteExecutionOptions](interfaces/WriteExecutionOptions.md)
- [WriteProgress](interfaces/WriteProgress.md)
## Type Aliases

View File

@@ -19,3 +19,39 @@ mode: "append" | "overwrite";
If "append" (the default) then the new data will be added to the table
If "overwrite" then the new data will replace the existing data in the table.
***
### progress()
```ts
progress: (progress) => void;
```
Optional callback invoked periodically with write progress.
The callback is fired once per batch written and once more with
`done: true` when the write completes. Calls are dispatched
asynchronously to the JS event loop and never block the write — a slow
callback will queue events rather than back-pressure the writer.
Errors thrown from the callback are logged with `console.warn` and
swallowed — they do not abort the write.
#### Parameters
* **progress**: [`WriteProgress`](WriteProgress.md)
#### Returns
`void`
#### Example
```ts
await table.add(data, {
progress: (p) => {
console.log(`${p.outputRows}/${p.totalRows ?? "?"} rows`);
},
});
```

View File

@@ -64,18 +64,6 @@ client used by manifest-enabled native connections.
***
### oauthConfig?
```ts
optional oauthConfig: NativeOAuthConfig;
```
(For LanceDB cloud only): OAuth configuration for IdP-based
authentication (e.g., Azure Entra ID). When set, token acquisition
and refresh are handled entirely in Rust.
***
### readConsistencyInterval?
```ts

View File

@@ -0,0 +1,29 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / RenameTableOptions
# Interface: RenameTableOptions
## Properties
### namespacePath?
```ts
optional namespacePath: string[];
```
The namespace path of the table being renamed. Defaults to the root
namespace (`[]`) when omitted.
***
### newNamespacePath?
```ts
optional newNamespacePath: string[];
```
The namespace path to move the table to as part of the rename. When
omitted the table stays in `namespacePath`.

View File

@@ -0,0 +1,84 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / WriteProgress
# Interface: WriteProgress
Progress snapshot for a write operation, delivered to the `progress`
callback passed to [Table.add](../classes/Table.md#add).
## Properties
### activeTasks
```ts
activeTasks: number;
```
Number of parallel write tasks currently in flight.
***
### done
```ts
done: boolean;
```
`true` for the final callback; `false` otherwise.
***
### elapsedSeconds
```ts
elapsedSeconds: number;
```
Wall-clock seconds since the write started.
***
### outputBytes
```ts
outputBytes: number;
```
Number of bytes written so far.
***
### outputRows
```ts
outputRows: number;
```
Number of rows written so far.
***
### totalRows?
```ts
optional totalRows: number;
```
Total rows expected, when the input source reports it.
Always set on the final callback (the one with `done: true`), falling
back to the actual number of rows written when the source could not
report a row count up front.
***
### totalTasks
```ts
totalTasks: number;
```
Total number of parallel write tasks (the write parallelism).

View File

@@ -166,6 +166,12 @@ lists the indices that LanceDb supports.
::: lancedb.index.IvfFlat
::: lancedb.index.IvfSq
::: lancedb.index.IvfRq
::: lancedb.index.HnswFlat
::: lancedb.table.IndexStatistics
## Querying (Asynchronous)

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.29.1-beta.0</version>
<version>0.30.0-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.29.1-beta.0</version>
<version>0.30.0-beta.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>7.0.0-beta.13</lance-core.version>
<lance-core.version>7.0.0-rc.1</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.29.1-beta.0"
version = "0.30.0-beta.0"
publish = false
license.workspace = true
description.workspace = true

View File

@@ -7,6 +7,13 @@ import * as tmp from "tmp";
import { Connection, Table, connect, connectNamespace } from "../lancedb";
import { LocalTable } from "../lancedb/table";
// The `_versions` directory may also contain `latest_version_hint.json`, which
// Lance writes on non-lexically-ordered stores (e.g. the local filesystem) to
// speed up latest-version lookup. Only assert on manifest files.
function manifestFiles(versionsDir: string): string[] {
return readdirSync(versionsDir).filter((f) => f.endsWith(".manifest"));
}
describe("when connecting", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
@@ -47,6 +54,14 @@ describe("given a connection", () => {
await db.close();
expect(db.isOpen()).toBe(false);
await expect(db.tableNames()).rejects.toThrow("Connection is closed");
await expect(db.renameTable("a", "b")).rejects.toThrow(
"Connection is closed",
);
});
it("should report renameTable as unsupported on an OSS connection", async () => {
await db.createTable("a", [{ id: 1 }]);
await expect(db.renameTable("a", "b")).rejects.toThrow(/not supported/);
});
it("should be able to create a table from an object arg `createTable(options)`, or args `createTable(name, data, options)`", async () => {
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
@@ -81,16 +96,6 @@ describe("given a connection", () => {
await db.createTable("test4", [{ id: 1 }, { id: 2 }]);
});
it("should expose renameTable and reject on OSS listing DB", async () => {
await db.createTable("old_name", [{ id: 1 }]);
await expect(db.renameTable("old_name", "new_name")).rejects.toThrow(
"rename_table is not supported in LanceDB OSS",
);
await expect(db.tableNames()).resolves.toEqual(["old_name"]);
});
it("should fail if creating table twice, unless overwrite is true", async () => {
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
await expect(tbl.countRows()).resolves.toBe(2);
@@ -173,7 +178,7 @@ describe("given a connection", () => {
let manifestDir =
tmpDir.name + "/test_manifest_paths_v2_empty.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
manifestFiles(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
@@ -182,7 +187,7 @@ describe("given a connection", () => {
})) as LocalTable;
expect(await table.usesV2ManifestPaths()).toBe(true);
manifestDir = tmpDir.name + "/test_manifest_paths_v2.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
manifestFiles(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
});
@@ -201,14 +206,14 @@ describe("given a connection", () => {
const manifestDir =
tmpDir.name + "/test_manifest_path_migration.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
manifestFiles(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d\.manifest$/);
});
await table.migrateManifestPathsV2();
expect(await table.usesV2ManifestPaths()).toBe(true);
readdirSync(manifestDir).forEach((file) => {
manifestFiles(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
});

View File

@@ -617,4 +617,68 @@ describe("remote connection", () => {
);
});
});
describe("renameTable", () => {
async function captureRenameRequest(
call: (db: Connection) => Promise<void>,
): Promise<{ url: string; body: Record<string, unknown> }> {
let captured: { url: string; body: Record<string, unknown> } | undefined;
await withMockDatabase((req, res) => {
let raw = "";
req.on("data", (chunk) => {
raw += chunk;
});
req.on("end", () => {
captured = {
url: req.url ?? "",
body: raw ? JSON.parse(raw) : {},
};
res.writeHead(200, { "Content-Type": "application/json" }).end("");
});
}, call);
if (!captured) {
throw new Error("mock server never saw a request");
}
return captured;
}
it("sends rename request for a table in the root namespace", async () => {
const { url, body } = await captureRenameRequest(async (db) => {
await db.renameTable("table1", "table2");
});
expect(url).toBe("/v1/table/table1/rename/");
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
expect(body).toEqual({ new_table_name: "table2" });
});
it("omits new_namespace when only the current namespace is supplied", async () => {
// Safe-default check: passing namespacePath alone must not send
// `new_namespace`, so the server keeps the table in its current
// namespace instead of silently moving it to root.
const { url, body } = await captureRenameRequest(async (db) => {
await db.renameTable("table1", "table2", {
namespacePath: ["ns1"],
});
});
expect(url).toBe("/v1/table/ns1$table1/rename/");
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
expect(body).toEqual({ new_table_name: "table2" });
});
it("includes new_namespace in the body for a cross-namespace rename", async () => {
const { url, body } = await captureRenameRequest(async (db) => {
await db.renameTable("table1", "table2", {
namespacePath: ["ns1"],
newNamespacePath: ["ns2"],
});
});
expect(url).toBe("/v1/table/ns1$table1/rename/");
expect(body).toEqual({
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
new_table_name: "table2",
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
new_namespace: ["ns2"],
});
});
});
});

View File

@@ -28,6 +28,7 @@ import {
List,
Schema,
SchemaLike,
Struct,
Type,
Uint8,
Utf8,
@@ -115,6 +116,48 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
await expect(table.countRows()).resolves.toBe(1);
});
it("should invoke the progress callback", async () => {
const events: import("../lancedb").WriteProgress[] = [];
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }], {
progress: (p) => events.push(p),
});
expect(events.length).toBeGreaterThan(0);
const last = events[events.length - 1];
expect(last.done).toBe(true);
// Earlier callbacks must have done=false.
for (const ev of events.slice(0, -1)) {
expect(ev.done).toBe(false);
}
// outputRows reflects the rows added in this call, not table size.
expect(last.outputRows).toBe(3);
// The input source (an array) reports a row count, so totalRows is set.
expect(last.totalRows).toBe(3);
// outputRows is monotonic.
for (let i = 1; i < events.length; i++) {
expect(events[i].outputRows).toBeGreaterThanOrEqual(
events[i - 1].outputRows,
);
}
});
it("should swallow errors thrown from the progress callback", async () => {
const warn = jest
.spyOn(console, "warn")
.mockImplementation(() => undefined);
try {
const res = await table.add([{ id: 1 }, { id: 2 }], {
progress: () => {
throw new Error("callback bomb");
},
});
expect(res.version).toBeGreaterThan(0);
expect(warn).toHaveBeenCalled();
} finally {
warn.mockRestore();
}
});
it("should let me close the table", async () => {
expect(table.isOpen()).toBe(true);
table.close();
@@ -738,6 +781,113 @@ describe("When creating an index", () => {
expect(indices2.length).toBe(0);
});
it("should create and search a nested vector index", async () => {
const db = await connect(tmpDir.name);
const nestedSchema = new Schema([
new Field("id", new Int32(), true),
new Field(
"image",
new Struct([
new Field(
"embedding",
new FixedSizeList(2, new Field("item", new Float32(), true)),
true,
),
]),
true,
),
]);
const nestedTable = await db.createTable(
"nested_vector",
makeArrowTable(
Array.from({ length: 300 }, (_, id) => ({
id,
image: { embedding: [id, id + 1] },
})),
{ schema: nestedSchema },
),
);
await nestedTable.createIndex("image.embedding", {
name: "image_embedding_idx",
});
const indices = await nestedTable.listIndices();
expect(indices).toContainEqual({
name: "image_embedding_idx",
indexType: "IvfPq",
columns: ["image.embedding"],
});
const explicit = await nestedTable
.query()
.nearestTo([0.0, 1.0])
.column("image.embedding")
.limit(1)
.toArray();
const inferred = await nestedTable
.query()
.nearestTo([0.0, 1.0])
.limit(1)
.toArray();
expect(inferred[0].id).toEqual(explicit[0].id);
});
it("should report multiple nested vector candidates", async () => {
const db = await connect(tmpDir.name);
const nestedSchema = new Schema([
new Field(
"image",
new Struct([
new Field(
"embedding",
new FixedSizeList(2, new Field("item", new Float32(), true)),
true,
),
]),
true,
),
new Field(
"text",
new Struct([
new Field(
"embedding",
new FixedSizeList(2, new Field("item", new Float32(), true)),
true,
),
]),
true,
),
]);
const nestedTable = await db.createTable(
"multiple_nested_vectors",
makeArrowTable(
[
{
image: { embedding: [0.0, 1.0] },
text: { embedding: [2.0, 3.0] },
},
],
{ schema: nestedSchema },
),
);
await expect(
nestedTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
).rejects.toThrow(/image\.embedding.*text\.embedding/);
});
it("should report when no default vector column exists", async () => {
const db = await connect(tmpDir.name);
const noVectorTable = await db.createTable(
"no_vector",
makeArrowTable([{ id: 0, label: "cat" }]),
);
await expect(
noVectorTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
).rejects.toThrow(/No vector column/);
});
it("should wait for index readiness", async () => {
// Create an index and then wait for it to be ready
await tbl.createIndex("vec");

View File

@@ -144,6 +144,19 @@ export interface DropNamespaceOptions {
behavior?: "restrict" | "cascade";
}
export interface RenameTableOptions {
/**
* The namespace path of the table being renamed. Defaults to the root
* namespace (`[]`) when omitted.
*/
namespacePath?: string[];
/**
* The namespace path to move the table to as part of the rename. When
* omitted the table stays in `namespacePath`.
*/
newNamespacePath?: string[];
}
/**
* A LanceDB Connection that allows you to open tables and create new ones.
*
@@ -296,12 +309,6 @@ export abstract class Connection {
*/
abstract dropTable(name: string, namespacePath?: string[]): Promise<void>;
abstract renameTable(
oldName: string,
newName: string,
namespacePath?: string[],
): Promise<void>;
/**
* Drop all tables in the database.
* @param {string[]} namespacePath The namespace path to drop tables from (defaults to root namespace).
@@ -397,6 +404,24 @@ export abstract class Connection {
isShallow?: boolean;
},
): Promise<Table>;
/**
* Rename a table.
*
* Currently only supported by LanceDB Cloud. Local OSS connections and
* namespace-backed connections (via {@link connectNamespace}) reject with
* a "not supported" error.
*
* @param {string} currentName - The current name of the table.
* @param {string} newName - The new name for the table.
* @param {RenameTableOptions} options - Optional namespace paths. When
* `newNamespacePath` is omitted the table stays in `namespacePath`.
*/
abstract renameTable(
currentName: string,
newName: string,
options?: RenameTableOptions,
): Promise<void>;
}
/** @hideconstructor */
@@ -615,14 +640,6 @@ export class LocalConnection extends Connection {
return this.inner.dropTable(name, namespacePath ?? []);
}
async renameTable(
oldName: string,
newName: string,
namespacePath?: string[],
): Promise<void> {
return this.inner.renameTable(oldName, newName, namespacePath ?? []);
}
async dropAllTables(namespacePath?: string[]): Promise<void> {
return this.inner.dropAllTables(namespacePath ?? []);
}
@@ -665,6 +682,19 @@ export class LocalConnection extends Connection {
options?.behavior,
);
}
async renameTable(
currentName: string,
newName: string,
options?: RenameTableOptions,
): Promise<void> {
return this.inner.renameTable(
currentName,
newName,
options?.namespacePath ?? [],
options?.newNamespacePath,
);
}
}
/**

View File

@@ -50,7 +50,6 @@ export {
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
OAuthConfig as NativeOAuthConfig,
} from "./native.js";
export {
@@ -72,6 +71,7 @@ export {
CreateNamespaceResponse,
DropNamespaceResponse,
DescribeNamespaceResponse,
RenameTableOptions,
} from "./connection";
export { Session } from "./native.js";
@@ -114,6 +114,7 @@ export {
UpdateOptions,
OptimizeOptions,
Version,
WriteProgress,
LsmWriteSpec,
ColumnAlteration,
} from "./table";
@@ -125,8 +126,6 @@ export {
TokenResponse,
} from "./header";
export { OAuthConfig, OAuthFlowType } from "./oauth";
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";

View File

@@ -1,82 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
/**
* OAuth authentication flow types.
*/
export enum OAuthFlowType {
/** Client Credentials grant (service-to-service / M2M). */
ClientCredentials = "client_credentials",
/** Authorization Code with PKCE (interactive browser-based auth). */
AuthorizationCodePKCE = "authorization_code_pkce",
/** Device Code grant (CLI / headless environments). */
DeviceCode = "device_code",
/** Azure Managed Identity via IMDS. */
AzureManagedIdentity = "azure_managed_identity",
/** Workload Identity Federation (K8s, GitHub Actions). */
WorkloadIdentity = "workload_identity",
}
/**
* OAuth configuration for LanceDB authentication.
*
* All token acquisition and refresh is handled in the Rust layer.
* This config is passed through to Rust via napi-rs.
*
* @example Client Credentials (service-to-service):
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* clientSecret: "secret",
* scopes: ["api://lancedb-api/.default"],
* };
* ```
*
* @example Azure Managed Identity:
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* scopes: ["api://lancedb-api/.default"],
* flow: OAuthFlowType.AzureManagedIdentity,
* };
* ```
*/
export interface OAuthConfig {
/**
* OIDC issuer URL or OAuth authority URL.
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
*/
issuerUrl: string;
/** Application / Client ID. */
clientId: string;
/**
* OAuth scopes to request.
* For Azure: `["api://{app_id}/.default"]`
*/
scopes: string[];
/** Authentication flow (default: ClientCredentials). */
flow?: OAuthFlowType;
/** Client secret (required for ClientCredentials). */
clientSecret?: string;
/** Redirect URI (AuthorizationCodePKCE flow). */
redirectUri?: string;
/** Port for local callback server (AuthorizationCodePKCE, default: 8400). */
callbackPort?: number;
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
managedIdentityClientId?: string;
/** Path to federated token file (WorkloadIdentity). */
tokenFile?: string;
/** Seconds before expiry to trigger proactive refresh (default: 300). */
refreshBufferSecs?: number;
}

View File

@@ -46,6 +46,33 @@ import { sanitizeType } from "./sanitize";
import { IntoSql, toSQL } from "./util";
export { IndexConfig } from "./native";
/**
* Progress snapshot for a write operation, delivered to the `progress`
* callback passed to {@link Table.add}.
*/
export interface WriteProgress {
/** Number of rows written so far. */
outputRows: number;
/** Number of bytes written so far. */
outputBytes: number;
/**
* Total rows expected, when the input source reports it.
*
* Always set on the final callback (the one with `done: true`), falling
* back to the actual number of rows written when the source could not
* report a row count up front.
*/
totalRows?: number;
/** Wall-clock seconds since the write started. */
elapsedSeconds: number;
/** Number of parallel write tasks currently in flight. */
activeTasks: number;
/** Total number of parallel write tasks (the write parallelism). */
totalTasks: number;
/** `true` for the final callback; `false` otherwise. */
done: boolean;
}
/**
* Options for adding data to a table.
*/
@@ -56,6 +83,28 @@ export interface AddDataOptions {
* If "overwrite" then the new data will replace the existing data in the table.
*/
mode: "append" | "overwrite";
/**
* Optional callback invoked periodically with write progress.
*
* The callback is fired once per batch written and once more with
* `done: true` when the write completes. Calls are dispatched
* asynchronously to the JS event loop and never block the write — a slow
* callback will queue events rather than back-pressure the writer.
*
* Errors thrown from the callback are logged with `console.warn` and
* swallowed — they do not abort the write.
*
* @example
* ```ts
* await table.add(data, {
* progress: (p) => {
* console.log(`${p.outputRows}/${p.totalRows ?? "?"} rows`);
* },
* });
* ```
*/
progress: (progress: WriteProgress) => void;
}
export interface UpdateOptions {
@@ -488,19 +537,14 @@ export abstract class Table {
*
* `LsmWriteSpec` chooses one of three sharding strategies via `specType`:
*
* - `"bucket"` — hash-bucket writes by the single-column unenforced primary
* key (`column` and `numBuckets` required).
* - `"bucket"` — hash-bucket writes by a scalar `column` (`column` and
* `numBuckets` required).
* - `"identity"` — shard by the raw value of a scalar `column`.
* - `"unsharded"` — route every write to a single shard.
*
* All variants require the table to have an unenforced primary key
* ({@link Table#setUnenforcedPrimaryKey}); bucket sharding additionally
* requires it to be the single column being bucketed.
* @param {LsmWriteSpec} spec The sharding spec to install.
* @returns {Promise<void>}
* @example
* ```ts
* await table.setUnenforcedPrimaryKey("id");
* await table.setLsmWriteSpec({
* specType: "bucket",
* column: "id",
@@ -705,7 +749,20 @@ export class LocalTable extends Table {
const schema = await this.schema();
const buffer = await fromDataToBuffer(data, undefined, schema);
return await this.inner.add(buffer, mode);
// Wrap the user callback so a thrown error doesn't surface as an
// unhandled exception (the callback fires from a napi threadsafe
// function — exceptions there crash the process).
const userProgress = options?.progress;
const progress = userProgress
? (p: WriteProgress) => {
try {
userProgress(p);
} catch (e) {
console.warn("Table.add progress callback threw:", e);
}
}
: undefined;
return await this.inner.add(buffer, mode, progress);
}
async update(

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.29.1-beta.0",
"version": "0.30.0-beta.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -112,11 +112,6 @@ impl Connection {
builder = builder.client_config(rust_config);
if let Some(oauth_config) = options.oauth_config {
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
builder = builder.oauth_config(config);
}
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}
@@ -333,20 +328,6 @@ impl Connection {
.default_error()
}
#[napi(catch_unwind)]
pub async fn rename_table(
&self,
old_name: String,
new_name: String,
namespace_path: Option<Vec<String>>,
) -> napi::Result<()> {
let ns = namespace_path.unwrap_or_default();
self.get_inner()?
.rename_table(&old_name, &new_name, &ns, &ns)
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn drop_all_tables(&self, namespace_path: Option<Vec<String>>) -> napi::Result<()> {
let ns = namespace_path.unwrap_or_default();
@@ -478,4 +459,23 @@ impl Connection {
transaction_id: resp.transaction_id,
})
}
/// Rename a table. `current_namespace_path` and `new_namespace_path` default to
/// the root namespace when omitted; the caller is expected to either pass both
/// or pass neither.
#[napi(catch_unwind)]
pub async fn rename_table(
&self,
current_name: String,
new_name: String,
current_namespace_path: Option<Vec<String>>,
new_namespace_path: Option<Vec<String>>,
) -> napi::Result<()> {
let cur_ns = current_namespace_path.unwrap_or_default();
let new_ns = new_namespace_path.unwrap_or_default();
self.get_inner()?
.rename_table(&current_name, &new_name, &cur_ns, &new_ns)
.await
.default_error()
}
}

View File

@@ -61,10 +61,6 @@ pub struct ConnectionOptions {
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
/// for testing purposes.
pub host_override: Option<String>,
/// (For LanceDB cloud only): OAuth configuration for IdP-based
/// authentication (e.g., Azure Entra ID). When set, token acquisition
/// and refresh are handled entirely in Rust.
pub oauth_config: Option<remote::OAuthConfig>,
}
#[napi(object)]

View File

@@ -140,67 +140,6 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
}
}
/// OAuth configuration for LanceDB authentication.
/// All token acquisition and refresh is handled in the Rust layer.
#[napi(object)]
#[derive(Debug, Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// OAuth scopes to request. For Azure: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow: "client_credentials", "authorization_code_pkce",
/// "device_code", "azure_managed_identity", "workload_identity"
pub flow: Option<String>,
/// Client secret (required for client_credentials).
pub client_secret: Option<String>,
/// Redirect URI (authorization_code_pkce flow).
pub redirect_uri: Option<String>,
/// Port for local callback server (authorization_code_pkce, default: 8400).
pub callback_port: Option<u16>,
/// Client ID for user-assigned managed identity (azure_managed_identity).
pub managed_identity_client_id: Option<String>,
/// Path to federated token file (workload_identity).
pub token_file: Option<String>,
/// Seconds before expiry to trigger proactive refresh (default: 300).
pub refresh_buffer_secs: Option<u32>,
}
impl From<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
fn from(config: OAuthConfig) -> Self {
use lancedb::remote::oauth::OAuthFlow;
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
"authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
redirect_uri: config.redirect_uri,
callback_port: config.callback_port,
},
"device_code" => OAuthFlow::DeviceCode,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: config.managed_identity_client_id,
},
"workload_identity" => OAuthFlow::WorkloadIdentity {
token_file: config
.token_file
.expect("tokenFile is required for workload_identity flow"),
},
other => panic!("Unknown OAuth flow type: {other}"),
};
Self {
issuer_url: config.issuer_url,
client_id: config.client_id,
client_secret: config.client_secret,
scopes: config.scopes,
flow,
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
}
}
}
impl From<ClientConfig> for lancedb::remote::ClientConfig {
fn from(config: ClientConfig) -> Self {
Self {

View File

@@ -9,6 +9,7 @@ use lancedb::table::{
OptimizeAction, OptimizeOptions, Table as LanceDbTable,
};
use napi::bindgen_prelude::*;
use napi::threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode};
use napi_derive::napi;
use crate::error::NapiErrorExt;
@@ -67,8 +68,16 @@ impl Table {
schema_to_buffer(&schema)
}
#[napi(catch_unwind)]
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<AddResult> {
#[napi(
catch_unwind,
ts_args_type = "buf: Buffer, mode: string, progressCallback?: (progress: WriteProgressInfo) => void"
)]
pub async fn add(
&self,
buf: Buffer,
mode: String,
progress_callback: Option<ProgressFn>,
) -> napi::Result<AddResult> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let batches = batches
@@ -92,6 +101,19 @@ impl Table {
return Err(napi::Error::from_reason(format!("Invalid mode: {}", mode)));
};
if let Some(tsfn) = progress_callback {
op = op.progress(move |p| {
// NonBlocking: dispatch onto the JS event loop without
// blocking the writer thread. With napi-rs's default
// unbounded queue, events are not dropped — a slow JS
// callback will just queue them.
tsfn.call(
WriteProgressInfo::from(p),
ThreadsafeFunctionCallMode::NonBlocking,
);
});
}
let res = op.execute().await.default_error()?;
Ok(res.into())
}
@@ -654,6 +676,44 @@ pub struct OptimizeStats {
pub prune: RemovalStats,
}
/// Progress snapshot for a write operation, delivered to the JS callback
/// passed to `Table.add`.
#[napi(object)]
#[derive(Clone, Debug)]
pub struct WriteProgressInfo {
/// Number of rows written so far.
pub output_rows: i64,
/// Number of bytes written so far.
pub output_bytes: i64,
/// Total rows expected, if the input source reports it.
/// Always set on the final callback (where `done` is `true`).
pub total_rows: Option<i64>,
/// Wall-clock seconds since monitoring started.
pub elapsed_seconds: f64,
/// Number of parallel write tasks currently in flight.
pub active_tasks: i64,
/// Total number of parallel write tasks (the write parallelism).
pub total_tasks: i64,
/// `true` for the final callback; `false` otherwise.
pub done: bool,
}
impl From<&lancedb::table::write_progress::WriteProgress> for WriteProgressInfo {
fn from(p: &lancedb::table::write_progress::WriteProgress) -> Self {
Self {
output_rows: p.output_rows() as i64,
output_bytes: p.output_bytes() as i64,
total_rows: p.total_rows().map(|n| n as i64),
elapsed_seconds: p.elapsed().as_secs_f64(),
active_tasks: p.active_tasks() as i64,
total_tasks: p.total_tasks() as i64,
done: p.done(),
}
}
}
type ProgressFn = ThreadsafeFunction<WriteProgressInfo, (), WriteProgressInfo, Status, false>;
/// A definition of a column alteration. The alteration changes the column at
/// `path` to have the new name `name`, to be nullable if `nullable` is true,
/// and to have the data type `data_type`. At least one of `rename` or `nullable`

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.32.1-beta.0"
current_version = "0.33.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -4,16 +4,26 @@ code is in the `src/` directory and the Python bindings are in the `lancedb/` di
Common commands:
* Bootstrap dev env: `uv run --extra tests --extra dev maturin develop --extras tests,dev`
* Build: `make develop`
* Format: `make format`
* Lint: `make check`
* Fix lints: `make fix`
* Test: `make test`
* Doc test: `make doctest`
* Test: `uv run --extra tests pytest python/tests -vv --durations=10 -m "not slow and not s3_test"`
* Run specific test: `uv run --extra tests pytest python/tests/<test_file>.py::<test_name> -q`
* Doc test: `uv run --extra tests pytest --doctest-modules python/lancedb`
Use the uv-managed environment declared by `uv.lock` for Python validation. Do
not treat system `python`, global `pytest`, or missing editable-install errors
as final blockers; bootstrap or enter the uv environment instead. `make test`
and `make doctest` assume the development environment is already prepared.
Before committing changes, run lints and then formatting.
When you change the Rust code, you will need to recompile the Python bindings: `make develop`.
When you change the Rust code, PyO3 binding code, or see a missing/stale
`lancedb._lancedb`, recompile the Python bindings with
`uv run --extra tests --extra dev maturin develop --extras tests,dev` before
running tests.
When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.32.1-beta.0"
version = "0.33.0-beta.0"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"

View File

@@ -320,7 +320,6 @@ async def connect_async(
session: Optional[Session] = None,
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config=None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -411,7 +410,6 @@ async def connect_async(
session,
manifest_enabled,
namespace_client_properties,
oauth_config,
)
)

View File

@@ -250,7 +250,6 @@ async def connect(
session: Optional[Session],
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config: Optional[Any] = None,
) -> Connection: ...
class RecordBatchStream:

View File

@@ -8,7 +8,17 @@ from abc import abstractmethod
from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterable,
List,
Literal,
Optional,
Union,
)
if sys.version_info >= (3, 12):
from typing import override
@@ -313,7 +323,7 @@ class DBConnection(EnforceOverrides):
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(name='my_table', version=1, ...)
LanceTable(name='my_table', ...)
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -334,7 +344,7 @@ class DBConnection(EnforceOverrides):
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(name='table2', version=1, ...)
LanceTable(name='table2', ...)
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -357,7 +367,7 @@ class DBConnection(EnforceOverrides):
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(name='table3', version=1, ...)
LanceTable(name='table3', ...)
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -391,7 +401,7 @@ class DBConnection(EnforceOverrides):
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(name='table4', version=1, ...)
LanceTable(name='table4', ...)
"""
raise NotImplementedError
@@ -568,15 +578,15 @@ class LanceDBConnection(DBConnection):
>>> db = lancedb.connect("./.lancedb")
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4}])
LanceTable(name='my_table', version=1, ...)
LanceTable(name='my_table', ...)
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
LanceTable(name='another_table', version=1, ...)
LanceTable(name='another_table', ...)
>>> sorted(db.table_names())
['another_table', 'my_table']
>>> len(db)
2
>>> db["my_table"]
LanceTable(name='my_table', version=1, ...)
LanceTable(name='my_table', ...)
>>> "my_table" in db
True
>>> db.drop_table("my_table")
@@ -847,11 +857,20 @@ class LanceDBConnection(DBConnection):
)
)
def _all_table_names(self) -> Generator[str, None, None]:
page_token = None
while True:
response = self.list_tables(page_token=page_token)
yield from response.tables
page_token = response.page_token
if not page_token:
return
def __len__(self) -> int:
return len(self.table_names())
return sum(1 for _ in self._all_table_names())
def __contains__(self, name: str) -> bool:
return name in self.table_names()
return name in self._all_table_names()
@override
def create_table(

View File

@@ -3,12 +3,14 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from datetime import timedelta
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
@@ -17,41 +19,40 @@ from typing import (
Type,
TypeVar,
Union,
Any,
)
import asyncio
import deprecation
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pydantic
from typing_extensions import Annotated
from lancedb.pydantic import PYDANTIC_VERSION
from lancedb._lancedb import fts_query_to_json
from lancedb.background_loop import LOOP
from lancedb.pydantic import PYDANTIC_VERSION
from . import __version__
from .arrow import AsyncRecordBatchReader
from .dependencies import pandas as pd
from .expr import Expr
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import flatten_columns
from .expr import Expr
from lancedb._lancedb import fts_query_to_json
from typing_extensions import Annotated
if TYPE_CHECKING:
import sys
import PIL
import polars as pl
from ._lancedb import Query as LanceQuery
from ._lancedb import FTSQuery as LanceFTSQuery
from ._lancedb import HybridQuery as LanceHybridQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from ._lancedb import TakeQuery as LanceTakeQuery
from ._lancedb import PyQueryRequest
from ._lancedb import Query as LanceQuery
from ._lancedb import TakeQuery as LanceTakeQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .common import VEC
from .pydantic import LanceModel
from .table import Table
@@ -718,6 +719,7 @@ class LanceQueryBuilder(ABC):
flatten: Optional[Union[int, bool]] = None,
*,
timeout: Optional[timedelta] = None,
**kwargs,
) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
@@ -735,9 +737,12 @@ class LanceQueryBuilder(ABC):
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
return tbl.to_pandas()
return tbl.to_pandas(**kwargs)
@abstractmethod
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
@@ -2352,6 +2357,7 @@ class AsyncQueryBase(object):
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
**kwargs,
) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
@@ -2384,10 +2390,13 @@ class AsyncQueryBase(object):
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
return (
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
).to_pandas()
).to_pandas(**kwargs)
async def to_polars(
self,
@@ -3340,16 +3349,18 @@ class BaseQueryBuilder(object):
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
async_iter = LOOP.run(self._inner.execute(max_batch_length, timeout))
async_reader = LOOP.run(
self._inner.to_batches(max_batch_length=max_batch_length, timeout=timeout)
)
def iter_sync():
try:
while True:
yield LOOP.run(async_iter.__anext__())
yield LOOP.run(async_reader.__anext__())
except StopAsyncIteration:
return
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
return pa.RecordBatchReader.from_batches(async_reader.schema, iter_sync())
def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
"""
@@ -3389,6 +3400,7 @@ class BaseQueryBuilder(object):
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
**kwargs,
) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
@@ -3421,8 +3433,11 @@ class BaseQueryBuilder(object):
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
return LOOP.run(self._inner.to_pandas(flatten, timeout))
return LOOP.run(self._inner.to_pandas(flatten, timeout, **kwargs))
def to_polars(
self,

View File

@@ -9,7 +9,6 @@ from typing import List, Optional
from lancedb import __version__
from .header import HeaderProvider
from .oauth import OAuthConfig, OAuthFlowType
__all__ = [
"TimeoutConfig",
@@ -17,8 +16,6 @@ __all__ = [
"TlsConfig",
"ClientConfig",
"HeaderProvider",
"OAuthConfig",
"OAuthFlowType",
]

View File

@@ -1,90 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
class OAuthFlowType(str, Enum):
"""OAuth authentication flow types."""
CLIENT_CREDENTIALS = "client_credentials"
"""Client Credentials grant (service-to-service / M2M)."""
AUTHORIZATION_CODE_PKCE = "authorization_code_pkce"
"""Authorization Code with PKCE (interactive browser-based auth)."""
DEVICE_CODE = "device_code"
"""Device Code grant (CLI / headless environments)."""
AZURE_MANAGED_IDENTITY = "azure_managed_identity"
"""Azure Managed Identity via IMDS."""
WORKLOAD_IDENTITY = "workload_identity"
"""Workload Identity Federation (K8s, GitHub Actions)."""
@dataclass
class OAuthConfig:
"""OAuth configuration for LanceDB authentication.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via PyO3.
Parameters
----------
issuer_url : str
OIDC issuer URL or OAuth authority URL.
For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
client_id : str
Application / Client ID.
scopes : List[str]
OAuth scopes to request.
For Azure: ``["api://{app_id}/.default"]``
flow : OAuthFlowType
Authentication flow to use. Default: CLIENT_CREDENTIALS.
client_secret : Optional[str]
Client secret (required for CLIENT_CREDENTIALS).
redirect_uri : Optional[str]
Redirect URI for AUTHORIZATION_CODE_PKCE flow.
callback_port : Optional[int]
Port for local HTTP callback server (AUTHORIZATION_CODE_PKCE, default: 8400).
managed_identity_client_id : Optional[str]
Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
token_file : Optional[str]
Path to federated token file (WORKLOAD_IDENTITY).
refresh_buffer_secs : Optional[int]
Seconds before expiry to trigger proactive refresh (default: 300).
Examples
--------
Client Credentials (service-to-service):
>>> config = OAuthConfig(
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
... client_id="app-id",
... client_secret="secret",
... scopes=["api://lancedb-api/.default"],
... )
Azure Managed Identity:
>>> config = OAuthConfig(
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
... client_id="app-id",
... scopes=["api://lancedb-api/.default"],
... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
... )
"""
issuer_url: str
client_id: str
scopes: List[str]
flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
client_secret: Optional[str] = None
redirect_uri: Optional[str] = None
callback_port: Optional[int] = None
managed_identity_client_id: Optional[str] = None
token_file: Optional[str] = None
refresh_buffer_secs: Optional[int] = None

View File

@@ -40,7 +40,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.table import _normalize_progress
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
from ..table import AsyncTable, BlobMode, IndexStatistics, Query, Table, Tags
from ..types import BaseTokenizerType
@@ -101,7 +101,7 @@ class RemoteTable(Table):
"""to_arrow() is not yet supported on LanceDB cloud."""
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
def to_pandas(self):
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs):
"""to_pandas() is not yet supported on LanceDB cloud."""
raise NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

View File

@@ -87,6 +87,8 @@ from .util import (
)
from .index import lang_mapping
BlobMode = Literal["lazy", "bytes", "descriptions"]
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
_MODEL_BACKED_TOKENIZER_ERRORS = (
"unknown base tokenizer",
@@ -760,14 +762,22 @@ class Table(ABC):
"""
raise NotImplementedError
def to_pandas(self) -> "pandas.DataFrame":
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pandas.DataFrame":
"""Return the table as a pandas DataFrame.
Parameters
----------
blob_mode: str, default "lazy"
Controls how blob columns are returned for backends that support
Lance blob-aware pandas conversion.
**kwargs
Forwarded to PyArrow / Lance pandas conversion.
Returns
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
return self.to_arrow().to_pandas(**kwargs)
@abstractmethod
def to_arrow(self) -> pa.Table:
@@ -2168,7 +2178,7 @@ class LanceTable(Table):
return LOOP.run(self._table.count_rows(filter))
def __repr__(self) -> str:
val = f"{self.__class__.__name__}(name={self.name!r}, version={self.version}"
val = f"{self.__class__.__name__}(name={self.name!r}"
if self._conn.read_consistency_interval is not None:
val += ", read_consistency_interval={!r}".format(
self._conn.read_consistency_interval
@@ -2183,14 +2193,27 @@ class LanceTable(Table):
"""Return the first n rows of the table."""
return LOOP.run(self._table.head(n))
def to_pandas(self) -> "pd.DataFrame":
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
Parameters
----------
blob_mode: str, default "lazy"
Controls how Lance blob columns are returned.
**kwargs
Forwarded to Lance pandas conversion.
Returns
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
if blob_mode == "lazy" and (
self._namespace_client is not None
or get_uri_scheme(self._dataset_path) == "memory"
):
return self.to_arrow().to_pandas(**kwargs)
return self.to_lance().to_pandas(blob_mode=blob_mode, **kwargs)
def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
@@ -2519,11 +2542,6 @@ class LanceTable(Table):
"at a time. To search over multiple text fields, create a "
"separate FTS index for each field."
)
if "." in field_names:
raise ValueError(
"Native FTS indexes can only be created on top-level fields. "
f"Received nested field path: {field_names!r}."
)
if tokenizer_name is None:
tokenizer_configs = {
@@ -3857,15 +3875,11 @@ class AsyncTable:
strategies:
- ``LsmWriteSpec.bucket(column, num_buckets)`` — hash-bucket writes by
the single-column unenforced primary key.
a scalar column.
- ``LsmWriteSpec.identity(column)`` — shard by the raw value of a
scalar column.
- ``LsmWriteSpec.unsharded()`` — route every write to a single shard.
All variants require the table to have an unenforced primary key set
via [`set_unenforced_primary_key`]; bucket sharding additionally
requires it to be the single column being bucketed.
Parameters
----------
spec : LsmWriteSpec
@@ -3874,7 +3888,6 @@ class AsyncTable:
Examples
--------
>>> from lancedb._lancedb import LsmWriteSpec
>>> # table.set_unenforced_primary_key("id")
>>> # table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 16))
"""
await self._inner.set_lsm_write_spec(spec)
@@ -3945,14 +3958,39 @@ class AsyncTable:
"""
return AsyncQuery(self._inner.query())
async def to_pandas(self) -> "pd.DataFrame":
async def _to_lance(self, **kwargs) -> lance.LanceDataset:
try:
import lance
except ImportError:
raise ImportError(
"The lance library is required to use this function. "
"Please install with `pip install pylance`."
)
return lance.dataset(
await self.uri(),
version=await self.version(),
storage_options=await self.latest_storage_options(),
**kwargs,
)
async def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
Parameters
----------
blob_mode: str, default "lazy"
Controls how Lance blob columns are returned.
**kwargs
Forwarded to PyArrow / Lance pandas conversion.
Returns
-------
pd.DataFrame
"""
return (await self.to_arrow()).to_pandas()
if blob_mode == "lazy":
return (await self.to_arrow()).to_pandas(**kwargs)
return (await self._to_lance()).to_pandas(blob_mode=blob_mode, **kwargs)
async def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.

View File

@@ -10,7 +10,7 @@ import pathlib
import warnings
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union, Optional, Any
from typing import Tuple, Union, Optional, Any, List
from urllib.parse import urlparse
import numpy as np
@@ -189,7 +189,33 @@ def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
return tbl
def inf_vector_column_query(schema: pa.Schema) -> str:
def _format_field_path(path: List[str]) -> str:
def format_segment(segment: str) -> str:
if all(char.isalnum() or char == "_" for char in segment):
return segment
return f"`{segment.replace('`', '``')}`"
return ".".join(format_segment(segment) for segment in path)
def _iter_vector_columns(
field: pa.Field, path: List[str], dim: Optional[int] = None
) -> List[str]:
field_path = [*path, field.name]
if is_vector_column(field.type):
vector_dim = infer_vector_column_dim(field.type)
if dim is None or vector_dim == dim:
return [_format_field_path(field_path)]
return []
if pa.types.is_struct(field.type):
columns = []
for idx in range(field.type.num_fields):
columns.extend(_iter_vector_columns(field.type.field(idx), field_path, dim))
return columns
return []
def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str:
"""
Get the vector column name
@@ -202,26 +228,21 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
-------
str: the vector column name.
"""
vector_col_name = ""
vector_col_count = 0
for field_name in schema.names:
field = schema.field(field_name)
if is_vector_column(field.type):
vector_col_count += 1
if vector_col_count > 1:
raise ValueError(
"Schema has more than one vector column. "
"Please specify the vector column name "
"for vector search"
)
elif vector_col_count == 1:
vector_col_name = field_name
if vector_col_count == 0:
vector_col_names = []
for field in schema:
vector_col_names.extend(_iter_vector_columns(field, [], dim))
if len(vector_col_names) > 1:
raise ValueError(
"Schema has more than one vector column. "
"Please specify the vector column name "
f"for vector search. Candidates: {vector_col_names}"
)
if len(vector_col_names) == 0:
raise ValueError(
"There is no vector column in the data. "
"Please specify the vector column name for vector search"
)
return vector_col_name
return vector_col_names[0]
def is_vector_column(data_type: pa.DataType) -> bool:
@@ -247,6 +268,29 @@ def is_vector_column(data_type: pa.DataType) -> bool:
return False
def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]:
if pa.types.is_fixed_size_list(data_type):
return data_type.list_size
if pa.types.is_list(data_type):
return infer_vector_column_dim(data_type.value_type)
return None
def _query_vector_dim(query: Optional[Any]) -> Optional[int]:
if query is None:
return None
if isinstance(query, np.ndarray):
if query.ndim == 0:
return None
return query.shape[-1]
if isinstance(query, list) and query:
first = query[0]
if isinstance(first, (list, tuple, np.ndarray)):
return len(first)
return len(query)
return None
def infer_vector_column_name(
schema: pa.Schema,
query_type: str,
@@ -262,7 +306,9 @@ def infer_vector_column_name(
if query is not None or query_type == "hybrid":
try:
vector_column_name = inf_vector_column_query(schema)
vector_column_name = inf_vector_column_query(
schema, dim=_query_vector_dim(query)
)
except Exception as e:
raise e

View File

@@ -6,6 +6,7 @@ import re
import sys
from datetime import timedelta
import os
from types import SimpleNamespace
import lancedb
import numpy as np
@@ -188,6 +189,43 @@ def test_table_names(tmp_db: lancedb.DBConnection):
assert len(result) == 3
def test_db_contains_and_len_include_all_table_name_pages(tmp_db: lancedb.DBConnection):
for idx in range(20):
tmp_db.create_table(f"table_{idx}", data=[{"id": idx}])
assert len(tmp_db) == 20
for idx in range(20):
assert f"table_{idx}" in tmp_db
assert "does_not_exist" not in tmp_db
def test_db_contains_stops_after_matching_table_page(
tmp_db: lancedb.DBConnection, monkeypatch
):
calls = []
pages = {
None: SimpleNamespace(tables=["table_0", "table_1"], page_token="next"),
"next": SimpleNamespace(tables=["table_2"], page_token=None),
}
def list_tables(*, page_token=None, **_kwargs):
calls.append(page_token)
return pages[page_token]
monkeypatch.setattr(tmp_db, "list_tables", list_tables)
assert "table_1" in tmp_db
assert calls == [None]
calls.clear()
assert "table_2" in tmp_db
assert calls == [None, "next"]
calls.clear()
assert len(tmp_db) == 3
assert calls == [None, "next"]
@pytest.mark.asyncio
async def test_table_names_async(tmp_path):
db = lancedb.connect(tmp_path)
@@ -412,6 +450,13 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
# await db.create_table("test", schema=bad_schema, exist_ok=True)
def _manifest_files(versions_dir):
# The `_versions` directory may also contain `latest_version_hint.json`,
# which Lance writes on non-lexically-ordered stores (e.g. the local
# filesystem) to speed up latest-version lookup. Only assert on manifests.
return [f for f in os.listdir(versions_dir) if f.endswith(".manifest")]
@pytest.mark.asyncio
async def test_create_table_v2_manifest_paths_async(tmp_path):
db_with_v2_paths = await lancedb.connect_async(
@@ -427,7 +472,7 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
)
assert await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
for manifest in _manifest_files(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)
# Start a table in V1 mode then migrate
@@ -437,13 +482,13 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
)
assert not await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
for manifest in _manifest_files(manifests_dir):
assert re.match(r"\d\.manifest", manifest)
await tbl.migrate_manifest_paths_v2()
assert await tbl.uses_v2_manifest_paths()
for manifest in os.listdir(manifests_dir):
for manifest in _manifest_files(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)

View File

@@ -563,8 +563,111 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_nested_schema(tmp_path, table):
with pytest.raises(ValueError, match="top-level fields"):
table.create_fts_index("nested.text")
table.create_fts_index("nested.text", with_position=True)
indices = table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "FTS"
assert indices[0].columns == ["nested.text"]
results = (
table.search("puppy", query_type="fts", fts_columns="nested.text")
.limit(5)
.to_list()
)
assert len(results) > 0
assert all("puppy" in row["nested"]["text"] for row in results)
results = table.search(MatchQuery("puppy", "nested.text")).limit(5).to_list()
assert len(results) > 0
assert all("puppy" in row["nested"]["text"] for row in results)
phrase_results = (
table.search(PhraseQuery("puppy runs", "nested.text")).limit(5).to_list()
)
assert len(phrase_results) > 0
assert all("puppy runs" in row["nested"]["text"] for row in phrase_results)
hybrid_results = (
table.search(query_type="hybrid", fts_columns="nested.text")
.vector([0 for _ in range(128)])
.text("puppy")
.limit(5)
.to_list()
)
assert len(hybrid_results) > 0
@pytest.mark.asyncio
async def test_nested_schema_async(async_table):
await async_table.create_index("nested.text", config=FTS(with_position=True))
indices = await async_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "FTS"
assert indices[0].columns == ["nested.text"]
results = await (
async_table.query()
.nearest_to_text("puppy", columns="nested.text")
.limit(5)
.to_list()
)
assert len(results) > 0
assert all("puppy" in row["nested"]["text"] for row in results)
results = await (
async_table.query()
.nearest_to_text(MatchQuery("puppy", "nested.text"))
.limit(5)
.to_list()
)
assert len(results) > 0
assert all("puppy" in row["nested"]["text"] for row in results)
phrase_results = await (
async_table.query()
.nearest_to_text(PhraseQuery("puppy runs", "nested.text"))
.limit(5)
.to_list()
)
assert len(phrase_results) > 0
assert all("puppy runs" in row["nested"]["text"] for row in phrase_results)
hybrid_results = await (
async_table.query()
.nearest_to([0 for _ in range(128)])
.nearest_to_text("puppy", columns="nested.text")
.limit(5)
.to_list()
)
assert len(hybrid_results) > 0
def test_nested_schema_rejects_invalid_fts_fields(tmp_path):
db = ldb.connect(tmp_path)
data = pa.table(
{
"payload": pa.array(
[
{"text": "puppy runs", "count": 1},
{"text": "car drives", "count": 2},
]
),
"vector": pa.array(
[[0.1, 0.1], [0.2, 0.2]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test", data=data)
with pytest.raises(ValueError, match="FTS index cannot be created.*payload"):
table.create_fts_index("payload")
with pytest.raises(ValueError, match="FTS index cannot be created.*count"):
table.create_fts_index("payload.count")
with pytest.raises(ValueError, match="Field path `payload.missing` not found"):
table.create_fts_index("payload.missing")
def test_search_index_with_filter(table):

View File

@@ -105,6 +105,46 @@ async def test_create_scalar_index(some_table: AsyncTable):
assert len(indices) == 0
@pytest.mark.asyncio
async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
metadata_type = pa.struct(
[
pa.field("user_id", pa.int32()),
pa.field("user.id", pa.int32()),
]
)
data = pa.Table.from_arrays(
[
pa.array([1, 2, 3], type=pa.int32()),
pa.array(
[
{"user_id": 10, "user.id": 100},
{"user_id": 20, "user.id": 200},
{"user_id": 30, "user.id": 300},
],
type=metadata_type,
),
],
names=["user_id", "metadata"],
)
table = await db_async.create_table("nested_scalar_index", data)
await table.create_index("user_id", config=BTree(), name="top_user_id_idx")
await table.create_index(
"metadata.user_id", config=BTree(), name="nested_user_id_idx"
)
await table.create_index(
"metadata.`user.id`", config=BTree(), name="escaped_user_id_idx"
)
columns_by_name = {
index.name: index.columns for index in await table.list_indices()
}
assert columns_by_name["top_user_id_idx"] == ["user_id"]
assert columns_by_name["nested_user_id_idx"] == ["metadata.user_id"]
assert columns_by_name["escaped_user_id_idx"] == ["metadata.`user.id`"]
@pytest.mark.asyncio
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
await some_table.create_index("fsb", config=BTree())

View File

@@ -40,16 +40,6 @@ def _make_table(tmp_path):
def test_set_lsm_write_spec_validates(tmp_path):
_db, table = _make_table(tmp_path)
# No PK set yet.
with pytest.raises(Exception, match="primary key"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
table.set_unenforced_primary_key("id")
# Column mismatch.
with pytest.raises(Exception, match="match"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("v", 4))
# Out-of-range num_buckets.
with pytest.raises(Exception, match="num_buckets"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 0))
@@ -70,7 +60,6 @@ def test_unset_lsm_write_spec(tmp_path):
table.unset_lsm_write_spec()
# Install a spec, then remove it; afterwards a fresh spec can be set.
table.set_unenforced_primary_key("id")
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
table.unset_lsm_write_spec()
# A second unset errors — there is no spec left to remove.
@@ -81,9 +70,7 @@ def test_unset_lsm_write_spec(tmp_path):
def test_set_unsharded_spec(tmp_path):
_db, table = _make_table(tmp_path)
# Lance MemWAL still requires a primary key on the dataset; Unsharded
# just skips per-row hashing.
table.set_unenforced_primary_key("id")
# Unsharded routes every write to a single shard, skipping per-row hashing.
table.set_lsm_write_spec(LsmWriteSpec.unsharded())
table.unset_lsm_write_spec()
@@ -120,7 +107,6 @@ async def test_async_set_unset_lsm_write_spec(tmp_path):
pa.RecordBatchReader.from_batches(SCHEMA, [_batch(["seed"], [0])]),
)
await table.set_unenforced_primary_key("id")
await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
await table.unset_lsm_write_spec()
# A second unset errors.
@@ -130,9 +116,7 @@ async def test_async_set_unset_lsm_write_spec(tmp_path):
def test_set_identity_spec(tmp_path):
_db, table = _make_table(tmp_path)
# Identity sharding still requires an unenforced primary key on the
# table; it shards by the raw value of the given column.
table.set_unenforced_primary_key("id")
# Identity sharding shards by the raw value of the given column.
table.set_lsm_write_spec(LsmWriteSpec.identity("v"))
table.unset_lsm_write_spec()

View File

@@ -165,6 +165,22 @@ def test_offset(table):
assert len(results_with_offset.to_pandas()) == 1
@pytest.mark.asyncio
async def test_query_to_pandas_kwargs(table, table_async):
sync_df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.select(["id"])
.limit(1)
.to_pandas(split_blocks=True)
)
assert sync_df["id"].tolist() == [1]
async_df = await (
table_async.query().select(["id"]).limit(2).to_pandas(split_blocks=True)
)
assert async_df["id"].tolist() == [1, 2]
def test_order_by_plain_query(mem_db):
table = mem_db.create_table(
"test_order_by",
@@ -1496,6 +1512,37 @@ def test_take_queries(tmp_path):
]
def test_take_queries_to_batches(tmp_path):
# Regression test for the sync take-query path: `to_batches` previously
# raised ``AttributeError: 'AsyncTakeQuery' object has no attribute
# 'execute'`` because the inherited ``BaseQueryBuilder.to_batches`` called
# ``execute`` on the async wrapper instead of the native query.
db = lancedb.connect(tmp_path)
data = pa.table({"idx": list(range(100)), "label": [str(i) for i in range(100)]})
table = db.create_table("test", data)
# Take by offset → to_batches
rs = list(table.take_offsets([5, 2, 17]).to_batches())
assert all(isinstance(b, pa.RecordBatch) for b in rs)
assert sum(b.num_rows for b in rs) == 3
assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17]
# Take by row id → to_batches
rs = list(table.take_row_ids([5, 2, 17]).to_batches())
assert all(isinstance(b, pa.RecordBatch) for b in rs)
assert sum(b.num_rows for b in rs) == 3
assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17]
# Take with select projection → to_batches preserves the projection
rs = list(table.take_row_ids([5, 2, 17]).select(["label"]).to_batches())
assert all(b.schema.names == ["label"] for b in rs)
assert sorted(v for b in rs for v in b.column("label").to_pylist()) == [
"17",
"2",
"5",
]
def test_getitems(tmp_path):
db = lancedb.connect(tmp_path)
data = pa.table(

View File

@@ -269,6 +269,25 @@ def test_table_unimplemented_functions():
table.to_pandas()
def test_table_to_pandas_not_supported():
def handler(request):
if request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with pytest.raises(NotImplementedError):
table.to_pandas()
with pytest.raises(NotImplementedError):
table.to_pandas(blob_mode="bytes", split_blocks=True)
def test_table_add_in_threadpool():
def handler(request):
if request.path == "/v1/table/test/insert/":

View File

@@ -33,7 +33,7 @@ def test_basic(mem_db: DBConnection):
table = mem_db.create_table("test", data=data)
assert table.name == "test"
assert "LanceTable(name='test', version=1, _conn=LanceDBConnection(" in repr(table)
assert "LanceTable(name='test', _conn=LanceDBConnection(" in repr(table)
expected_schema = pa.schema(
{
"vector": pa.list_(pa.float32(), 2),
@@ -47,6 +47,85 @@ def test_basic(mem_db: DBConnection):
assert table.to_arrow() == expected_data
def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
pd = pytest.importorskip("pandas")
data = pa.table({"id": [1, 2], "text": ["one", "two"]})
table = tmp_db.create_table("test_to_pandas_old_call", data=data)
expected = data.to_pandas()
pd.testing.assert_frame_equal(table.to_pandas(), expected)
def test_table_to_pandas_blob_bytes(tmp_db: DBConnection):
pytest.importorskip("lance")
data = pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
table = tmp_db.create_table("test_to_pandas_blob_bytes", data=data)
df = table.to_pandas(blob_mode="bytes")
assert df["blob"].tolist() == [b"hello", b"world"]
def test_table_to_pandas_kwargs(tmp_db: DBConnection):
pd = pytest.importorskip("pandas")
data = pa.table({"id": pa.array([1, 2], pa.int64())})
table = tmp_db.create_table("test_to_pandas_kwargs", data=data)
df = table.to_pandas(types_mapper=pd.ArrowDtype)
assert str(df["id"].dtype) == "int64[pyarrow]"
@pytest.mark.asyncio
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
pytest.importorskip("lance")
data = pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
table = await tmp_db_async.create_table(
"test_async_to_pandas_blob_bytes", data=data
)
df = await table.to_pandas(blob_mode="bytes")
assert df["blob"].tolist() == [b"hello", b"world"]
@pytest.mark.asyncio
async def test_async_table_to_pandas_kwargs(tmp_db_async: AsyncConnection):
pd = pytest.importorskip("pandas")
data = pa.table({"id": pa.array([1, 2], pa.int64())})
table = await tmp_db_async.create_table("test_async_to_pandas_kwargs", data=data)
df = await table.to_pandas(types_mapper=pd.ArrowDtype)
assert str(df["id"].dtype) == "int64[pyarrow]"
def test_create_table_infers_large_int_vectors(mem_db: DBConnection):
data = [{"vector": [0, 300]}]
@@ -1811,6 +1890,59 @@ def test_create_scalar_index(mem_db: DBConnection):
assert scalar_index.name == "custom_y_index"
def test_create_index_nested_field_paths(mem_db: DBConnection):
schema = pa.schema(
[
pa.field("metadata", pa.struct([pa.field("user_id", pa.int32())])),
pa.field(
"image",
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
),
]
)
data = pa.Table.from_pylist(
[
{
"metadata": {"user_id": i},
"image": {"embedding": [float(i), float(i + 1)]},
}
for i in range(256)
],
schema=schema,
)
table = mem_db.create_table("nested_index_paths", data=data)
table.create_scalar_index("metadata.user_id", name="metadata_user_id_idx")
table.create_index(
vector_column_name="image.embedding",
num_partitions=1,
num_sub_vectors=1,
name="image_embedding_idx",
)
indices = sorted(table.list_indices(), key=lambda idx: idx.name)
assert [(idx.name, idx.index_type, idx.columns) for idx in indices] == [
("image_embedding_idx", "IvfPq", ["image.embedding"]),
("metadata_user_id_idx", "BTree", ["metadata.user_id"]),
]
vector_results = (
table.search([0.0, 1.0], vector_column_name="image.embedding")
.limit(1)
.to_list()
)
assert len(vector_results) == 1
assert vector_results[0]["metadata"]["user_id"] == 0
default_vector_results = table.search([0.0, 1.0]).limit(1).to_list()
assert len(default_vector_results) == 1
assert default_vector_results[0]["metadata"]["user_id"] == 0
filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list()
assert len(filtered_results) == 1
assert filtered_results[0]["metadata"]["user_id"] == 42
def test_empty_query(mem_db: DBConnection):
table = mem_db.create_table(
"my_table",
@@ -1885,6 +2017,74 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
table.search(q).limit(1).to_arrow()
def test_search_infers_single_nested_vector(mem_db: DBConnection):
schema = pa.schema(
[
pa.field("id", pa.int32()),
pa.field(
"image",
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
),
]
)
data = pa.Table.from_pylist(
[
{"id": 0, "image": {"embedding": [0.0, 1.0]}},
{"id": 1, "image": {"embedding": [10.0, 11.0]}},
],
schema=schema,
)
table = mem_db.create_table("nested_vector_default_search", data=data)
result = table.search([0.0, 1.0]).limit(1).to_list()
assert result[0]["id"] == 0
def test_search_nested_vector_multiple_candidates(mem_db: DBConnection):
schema = pa.schema(
[
pa.field(
"image",
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
),
pa.field(
"text",
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
),
]
)
data = pa.Table.from_pylist(
[
{
"image": {"embedding": [0.0, 1.0]},
"text": {"embedding": [2.0, 3.0]},
}
],
schema=schema,
)
table = mem_db.create_table("nested_vector_multiple_candidates", data=data)
with pytest.raises(ValueError, match="image.embedding.*text.embedding"):
table.search([0.0, 1.0]).limit(1).to_arrow()
def test_search_nested_vector_no_candidates(mem_db: DBConnection):
schema = pa.schema(
[
pa.field("id", pa.int32()),
pa.field("metadata", pa.struct([pa.field("label", pa.string())])),
]
)
data = pa.Table.from_pylist(
[{"id": 0, "metadata": {"label": "cat"}}],
schema=schema,
)
table = mem_db.create_table("nested_vector_no_candidates", data=data)
with pytest.raises(ValueError, match="no vector column"):
table.search([0.0, 1.0]).limit(1).to_arrow()
def test_compact_cleanup(tmp_db: DBConnection):
pytest.importorskip("lance")
table = tmp_db.create_table(

View File

@@ -539,7 +539,7 @@ impl Connection {
}
#[pyfunction]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
#[allow(clippy::too_many_arguments)]
pub fn connect(
py: Python<'_>,
@@ -553,7 +553,6 @@ pub fn connect(
session: Option<crate::session::Session>,
manifest_enabled: bool,
namespace_client_properties: Option<HashMap<String, String>>,
oauth_config: Option<crate::oauth::PyOAuthConfig>,
) -> PyResult<Bound<'_, PyAny>> {
future_into_py(py, async move {
let mut builder = lancedb::connect(&uri);
@@ -583,10 +582,6 @@ pub fn connect(
if let Some(client_config) = client_config {
builder = builder.client_config(client_config.into());
}
if let Some(oauth_config) = oauth_config {
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
builder = builder.oauth_config(config);
}
if let Some(session) = session {
builder = builder.session(session.inner.clone());
}

View File

@@ -26,7 +26,6 @@ pub mod expr;
pub mod header;
pub mod index;
pub mod namespace;
pub mod oauth;
pub mod permutation;
pub mod query;
pub mod runtime;

View File

@@ -1,53 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use pyo3::FromPyObject;
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
/// Python-side OAuth configuration, extracted via FromPyObject.
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
#[derive(FromPyObject)]
pub struct PyOAuthConfig {
pub issuer_url: String,
pub client_id: String,
pub scopes: Vec<String>,
pub flow: String,
pub client_secret: Option<String>,
pub redirect_uri: Option<String>,
pub callback_port: Option<u16>,
pub managed_identity_client_id: Option<String>,
pub token_file: Option<String>,
pub refresh_buffer_secs: Option<u64>,
}
impl From<PyOAuthConfig> for OAuthConfig {
fn from(py: PyOAuthConfig) -> Self {
let flow = match py.flow.as_str() {
"client_credentials" => OAuthFlow::ClientCredentials,
"authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
redirect_uri: py.redirect_uri,
callback_port: py.callback_port,
},
"device_code" => OAuthFlow::DeviceCode,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: py.managed_identity_client_id,
},
"workload_identity" => OAuthFlow::WorkloadIdentity {
token_file: py
.token_file
.expect("token_file is required for workload_identity flow"),
},
other => panic!("Unknown OAuth flow type: {other}"),
};
OAuthConfig {
issuer_url: py.issuer_url,
client_id: py.client_id,
client_secret: py.client_secret,
scopes: py.scopes,
flow,
refresh_buffer_secs: py.refresh_buffer_secs,
}
}
}

View File

@@ -185,7 +185,7 @@ pub struct LsmWriteSpec {
#[pymethods]
impl LsmWriteSpec {
/// Hash-bucket sharding by the unenforced primary key column.
/// Hash-bucket sharding by a scalar column.
#[staticmethod]
pub fn bucket(column: String, num_buckets: u32) -> Self {
Self {

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.29.1-beta.0"
version = "0.30.0-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -75,11 +75,6 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"stream",
], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
# OAuth dependencies (used by remote feature)
sha2 = { version = "0.10", optional = true }
base64 = { version = "0.22", optional = true }
urlencoding = { version = "2", optional = true }
open = { version = "5", optional = true }
uuid = { version = "1.7.0", features = ["v4"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
@@ -133,7 +128,7 @@ huggingface = [
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "dep:sha2", "dep:base64", "dep:urlencoding", "dep:open", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -622,8 +622,6 @@ pub struct ConnectRequest {
pub struct ConnectBuilder {
request: ConnectRequest,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
#[cfg(feature = "remote")]
oauth_config: Option<crate::remote::oauth::OAuthConfig>,
}
#[cfg(feature = "remote")]
@@ -645,8 +643,6 @@ impl ConnectBuilder {
session: None,
},
embedding_registry: None,
#[cfg(feature = "remote")]
oauth_config: None,
}
}
@@ -735,19 +731,6 @@ impl ConnectBuilder {
self
}
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
///
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// from the given config and sets it as the header provider, replacing any
/// previously configured header provider or API key.
///
/// Token acquisition and refresh are handled entirely in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self {
self.oauth_config = Some(config);
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
@@ -891,29 +874,9 @@ impl ConnectBuilder {
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
})?;
// When OAuth is configured, api_key is not required
let api_key = match (&self.oauth_config, &options.api_key) {
(Some(_), None) => String::new(),
(Some(_), Some(key)) => key.clone(),
(None, Some(key)) => key.clone(),
(None, None) => {
return Err(Error::InvalidInput {
message:
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
.to_string(),
});
}
};
let mut client_config = self.request.client_config;
// Apply OAuth header provider if configured
if let Some(oauth_config) = self.oauth_config {
let provider = crate::remote::oauth::OAuthHeaderProvider::new(oauth_config)?;
client_config.header_provider =
Some(Arc::new(provider) as Arc<dyn crate::remote::client::HeaderProvider>);
}
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
@@ -921,7 +884,7 @@ impl ConnectBuilder {
&api_key,
&region,
options.host_override,
client_config,
self.request.client_config,
storage_options.into(),
)?);
Ok(Connection {

View File

@@ -271,15 +271,26 @@ impl Scannable for WithEmbeddingsScannable {
.map_err(|e| Error::Runtime {
message: format!("Task panicked during embedding computation: {}", e),
})??;
// Cast columns to match the declared output schema. The data is
// identical but field metadata (e.g. nested nullability) may
// differ between the embedding function output and the table.
let columns: Vec<ArrayRef> = result
.columns()
// Look up columns by name (not position) so the result matches
// the output schema even when columns appear in a different
// order — e.g. `add_columns` placed a new column after the
// embedding column, but the computed batch appends embeddings
// at the end. Cast per-column because field metadata (e.g.
// nested nullability) may also differ between the embedding
// function output and the table.
let columns: Vec<ArrayRef> = output_schema
.fields()
.iter()
.enumerate()
.map(|(i, col)| {
let target_type = output_schema.field(i).data_type();
.map(|field| {
let col = result.column_by_name(field.name()).ok_or_else(|| {
Error::InvalidInput {
message: format!(
"Column '{}' required by the table schema was not present in the input batch",
field.name()
),
}
})?;
let target_type = field.data_type();
if col.data_type() == target_type {
Ok(col.clone())
} else {
@@ -964,5 +975,118 @@ mod tests {
"Expected EmbeddingFunctionNotFound"
);
}
/// Regression test for https://github.com/lancedb/lancedb/issues/3136.
///
/// When a column is added to the table after the embedding column via
/// schema evolution, the table schema becomes
/// `[..., embedding, extra]`. The input batch (without the embedding)
/// is `[..., extra]`, and `compute_embeddings_for_batch` appends the
/// embedding at the end giving `[..., extra, embedding]`. A positional
/// cast to the output schema would map `extra` onto `embedding` and
/// fail with a CastError. Columns must be matched by name.
#[tokio::test]
async fn test_with_embeddings_scannable_column_added_after_embedding() {
let input_schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new("score", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
input_schema.clone(),
vec![
Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef,
Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])) as ArrayRef,
],
)
.unwrap();
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
// Table schema: embedding column is BEFORE `score`, as would
// happen if `score` was added via `add_columns` after creating
// the table with an embedding on `text`.
let output_schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new(
"text_vec",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
4,
),
false,
),
Field::new("score", DataType::Float64, true),
]));
let mut scannable = WithEmbeddingsScannable::with_schema(
Box::new(batch),
vec![(embedding_def, mock_embedding)],
output_schema.clone(),
)
.unwrap();
let stream = scannable.scan_as_stream();
let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
assert_eq!(results.len(), 1);
let result_batch = &results[0];
assert_eq!(result_batch.schema(), output_schema);
assert_eq!(result_batch.num_rows(), 2);
// Position 1 must actually hold the FixedSizeList embedding —
// not the score column reinterpreted by a permissive cast.
let embedding = result_batch
.column(1)
.as_any()
.downcast_ref::<arrow_array::FixedSizeListArray>()
.expect("position 1 should be a FixedSizeList embedding");
assert_eq!(embedding.value_length(), 4);
assert_eq!(embedding.null_count(), 0);
}
/// If the input batch is missing a non-embedding column required by
/// the table schema, we should return a clear error rather than
/// silently producing a malformed batch.
#[tokio::test]
async fn test_with_embeddings_scannable_missing_required_column() {
let input_schema =
Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
let batch = RecordBatch::try_new(
input_schema,
vec![Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef],
)
.unwrap();
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
let output_schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new(
"text_vec",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
4,
),
false,
),
Field::new("score", DataType::Float64, true),
]));
let mut scannable = WithEmbeddingsScannable::with_schema(
Box::new(batch),
vec![(embedding_def, mock_embedding)],
output_schema,
)
.unwrap();
let stream = scannable.scan_as_stream();
let results: Result<Vec<RecordBatch>> = stream.try_collect().await;
let err = results.expect_err("expected an error");
assert!(
matches!(&err, Error::InvalidInput { message } if message.contains("score")),
"expected InvalidInput about missing 'score' column, got: {err:?}"
);
}
}
}

View File

@@ -23,17 +23,12 @@ impl VectorIndex {
.fields
.iter()
.map(|field_id| {
manifest
.schema
.field_by_id(*field_id)
.unwrap_or_else(|| {
panic!(
"field {field_id} of index {} must exist in schema",
index.name
)
})
.name
.clone()
manifest.schema.field_path(*field_id).unwrap_or_else(|_| {
panic!(
"field {field_id} of index {} must exist in schema",
index.name
)
})
})
.collect();
Self {

View File

@@ -8,7 +8,6 @@
pub(crate) mod client;
pub(crate) mod db;
pub mod oauth;
mod retry;
pub(crate) mod table;
pub(crate) mod util;
@@ -21,4 +20,3 @@ const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};

View File

@@ -1,906 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use log::{debug, info, warn};
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::remote::client::HeaderProvider;
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
const DEFAULT_CALLBACK_PORT: u16 = 8400;
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
/// OAuth authentication flow configuration.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
/// Client Credentials grant (service-to-service / M2M).
/// Requires `client_secret` in [`OAuthConfig`].
ClientCredentials,
/// Authorization Code with PKCE (interactive browser-based auth).
AuthorizationCodePKCE {
/// Redirect URI (default: `http://localhost:{callback_port}/callback`)
redirect_uri: Option<String>,
/// Port for the local HTTP callback server (default: 8400)
callback_port: Option<u16>,
},
/// Device Code grant (CLI / headless environments).
/// Displays a verification URL and user code for out-of-band authentication.
DeviceCode,
/// Azure Managed Identity via IMDS.
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
AzureManagedIdentity {
/// Client ID for user-assigned managed identity.
/// Omit for system-assigned managed identity.
client_id: Option<String>,
},
/// Workload Identity Federation.
/// Exchanges a platform token (K8s service account, GitHub OIDC) for an IdP token.
WorkloadIdentity {
/// Path to the federated token file
/// (e.g. `AZURE_FEDERATED_TOKEN_FILE`).
token_file: String,
},
}
/// OAuth configuration for LanceDB authentication.
///
/// All token acquisition and refresh is handled in the Rust layer.
/// Python and TypeScript bindings expose this as a plain config object.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// Client secret (required for `ClientCredentials`, optional for others).
pub client_secret: Option<String>,
/// OAuth scopes to request.
/// For Azure: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow to use.
pub flow: OAuthFlow,
/// Seconds before token expiry to trigger proactive refresh (default: 300).
pub refresh_buffer_secs: Option<u64>,
}
// -- OIDC Discovery --
#[derive(Debug, Deserialize)]
struct OidcDiscovery {
token_endpoint: String,
authorization_endpoint: Option<String>,
device_authorization_endpoint: Option<String>,
}
// -- Token Response --
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
/// Token lifetime in seconds.
/// Some providers (Azure IMDS) return this as a string, so we accept both.
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
expires_in: Option<u64>,
#[serde(default)]
#[allow(dead_code)]
token_type: Option<String>,
}
fn deserialize_optional_u64_or_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct U64OrString;
impl<'de> de::Visitor<'de> for U64OrString {
type Value = Option<u64>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a u64, a numeric string, or null")
}
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
Ok(Some(v))
}
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
Ok(Some(v as u64))
}
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
v.parse::<u64>().map(Some).map_err(de::Error::custom)
}
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
}
deserializer.deserialize_any(U64OrString)
}
// -- Device Code Response --
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default)]
verification_uri_complete: Option<String>,
expires_in: u64,
interval: Option<u64>,
}
// -- Internal Token State --
struct TokenState {
access_token: Option<String>,
refresh_token: Option<String>,
expires_at: Option<Instant>,
}
impl TokenState {
fn new() -> Self {
Self {
access_token: None,
refresh_token: None,
expires_at: None,
}
}
fn is_expired(&self, buffer: Duration) -> bool {
match (self.access_token.as_ref(), self.expires_at) {
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
(None, _) => true,
(Some(_), None) => false, // no expiry info, assume valid
}
}
fn update(&mut self, resp: &TokenResponse) {
self.access_token = Some(resp.access_token.clone());
if resp.refresh_token.is_some() {
self.refresh_token = resp.refresh_token.clone();
}
self.expires_at = resp
.expires_in
.map(|secs| Instant::now() + Duration::from_secs(secs));
}
}
/// OAuth header provider that manages the full token lifecycle.
///
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
/// headers into every LanceDB request, with automatic token refresh.
pub struct OAuthHeaderProvider {
config: OAuthConfig,
http_client: Client,
token_state: Arc<RwLock<TokenState>>,
/// Cached OIDC discovery document
discovery: Arc<RwLock<Option<OidcDiscovery>>>,
refresh_buffer: Duration,
}
impl std::fmt::Debug for OAuthHeaderProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthHeaderProvider")
.field("issuer_url", &self.config.issuer_url)
.field("client_id", &self.config.client_id)
.field("flow", &self.config.flow)
.finish()
}
}
impl OAuthHeaderProvider {
/// Create a new OAuth header provider from configuration.
pub fn new(config: OAuthConfig) -> Result<Self> {
// Validate config upfront
if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() {
return Err(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
});
}
if config.scopes.is_empty() {
return Err(Error::InvalidInput {
message: "At least one OAuth scope is required".to_string(),
});
}
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for OAuth: {e}"),
})?;
let refresh_buffer = Duration::from_secs(
config
.refresh_buffer_secs
.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS),
);
Ok(Self {
config,
http_client,
token_state: Arc::new(RwLock::new(TokenState::new())),
discovery: Arc::new(RwLock::new(None)),
refresh_buffer,
})
}
/// Get a valid access token, refreshing if necessary.
async fn get_valid_token(&self) -> Result<String> {
// Fast path: check if current token is still valid
{
let state = self.token_state.read().await;
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
}
// Slow path: acquire or refresh token
let mut state = self.token_state.write().await;
// Double-check after acquiring write lock
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
let uses_refresh_token = !matches!(
self.config.flow,
OAuthFlow::ClientCredentials
| OAuthFlow::AzureManagedIdentity { .. }
| OAuthFlow::WorkloadIdentity { .. }
);
let resp = if let Some(ref refresh_token) = state.refresh_token
&& uses_refresh_token
{
debug!("Refreshing OAuth token using refresh_token");
self.refresh_with_token(refresh_token).await?
} else {
debug!("Acquiring new OAuth token via {:?} flow", self.config.flow);
self.acquire_token().await?
};
state.update(&resp);
Ok(resp.access_token)
}
/// Acquire a new token using the configured flow.
async fn acquire_token(&self) -> Result<TokenResponse> {
match &self.config.flow {
OAuthFlow::ClientCredentials => self.acquire_client_credentials().await,
OAuthFlow::AuthorizationCodePKCE {
redirect_uri,
callback_port,
} => {
self.acquire_authorization_code_pkce(
redirect_uri.as_deref(),
callback_port.unwrap_or(DEFAULT_CALLBACK_PORT),
)
.await
}
OAuthFlow::DeviceCode => self.acquire_device_code().await,
OAuthFlow::AzureManagedIdentity { client_id } => {
self.acquire_managed_identity(client_id.as_deref()).await
}
OAuthFlow::WorkloadIdentity { token_file } => {
self.acquire_workload_identity(token_file).await
}
}
}
// -- OIDC Discovery --
async fn get_discovery(&self) -> Result<OidcDiscovery> {
{
let cached = self.discovery.read().await;
if let Some(ref disc) = *cached {
return Ok(OidcDiscovery {
token_endpoint: disc.token_endpoint.clone(),
authorization_endpoint: disc.authorization_endpoint.clone(),
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
});
}
}
let mut cache = self.discovery.write().await;
// Double-check
if let Some(ref disc) = *cache {
return Ok(OidcDiscovery {
token_endpoint: disc.token_endpoint.clone(),
authorization_endpoint: disc.authorization_endpoint.clone(),
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
});
}
let discovery_url = format!(
"{}/.well-known/openid-configuration",
self.config.issuer_url.trim_end_matches('/')
);
debug!("Fetching OIDC discovery from {}", discovery_url);
let resp = self
.http_client
.get(&discovery_url)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to fetch OIDC discovery document: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"OIDC discovery failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse OIDC discovery document: {e}"),
})?;
let result = OidcDiscovery {
token_endpoint: disc.token_endpoint.clone(),
authorization_endpoint: disc.authorization_endpoint.clone(),
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
};
*cache = Some(disc);
Ok(result)
}
fn get_token_endpoint_from_issuer(&self) -> String {
// Derive Azure v2.0 token endpoint from issuer URL
// issuer: https://login.microsoftonline.com/{tenant}/v2.0
// token: https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
let base = self.config.issuer_url.trim_end_matches("/v2.0");
format!("{base}/oauth2/v2.0/token")
}
async fn get_token_endpoint(&self) -> Result<String> {
match self.get_discovery().await {
Ok(disc) => Ok(disc.token_endpoint),
Err(e) => {
warn!("OIDC discovery failed, falling back to derived endpoint: {e}");
Ok(self.get_token_endpoint_from_issuer())
}
}
}
fn scopes_string(&self) -> String {
self.config.scopes.join(" ")
}
// -- Client Credentials Flow --
async fn acquire_client_credentials(&self) -> Result<TokenResponse> {
let client_secret = self
.config
.client_secret
.as_ref()
.ok_or(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
})?;
let token_endpoint = self.get_token_endpoint().await?;
let params = [
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", client_secret),
("scope", &self.scopes_string()),
];
self.post_token_request(&token_endpoint, &params).await
}
// -- Authorization Code + PKCE Flow --
async fn acquire_authorization_code_pkce(
&self,
redirect_uri: Option<&str>,
callback_port: u16,
) -> Result<TokenResponse> {
use rand::Rng;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
let discovery = self.get_discovery().await?;
let auth_endpoint = discovery.authorization_endpoint.ok_or(Error::Runtime {
message: "OIDC discovery did not provide authorization_endpoint".to_string(),
})?;
let default_redirect = format!("http://localhost:{callback_port}/callback");
let redirect = redirect_uri.unwrap_or(&default_redirect);
// Generate PKCE code verifier and challenge (S256)
const PKCE_CHARSET: &[u8] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
let code_verifier: String = {
let mut rng = rand::rng();
(0..128)
.map(|_| {
let idx = rng.random_range(0..PKCE_CHARSET.len());
PKCE_CHARSET[idx] as char
})
.collect()
};
let code_challenge = {
use sha2::{Digest, Sha256};
let hash = Sha256::digest(code_verifier.as_bytes());
base64_url_encode(&hash)
};
let state: String = {
let mut rng = rand::rng();
(0..32)
.map(|_| {
let idx = rng.random_range(0..16u8);
b"0123456789abcdef"[idx as usize] as char
})
.collect()
};
// Build authorization URL
let auth_url = format!(
"{auth_endpoint}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={state}",
urlencoding::encode(&self.config.client_id),
urlencoding::encode(redirect),
urlencoding::encode(&self.scopes_string()),
urlencoding::encode(&code_challenge),
);
info!("Opening browser for OAuth login...");
info!("If the browser doesn't open, visit: {auth_url}");
// Try to open browser
let _ = open::that(&auth_url);
// Start local callback server
let listener = TcpListener::bind(format!("127.0.0.1:{callback_port}"))
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to bind callback server on port {callback_port}: {e}"),
})?;
info!("Waiting for OAuth callback on port {callback_port}...");
let (mut stream, _) = listener.accept().await.map_err(|e| Error::Runtime {
message: format!("Failed to accept callback connection: {e}"),
})?;
// Read the HTTP request
let mut buf = vec![0u8; 4096];
let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to read callback request: {e}"),
})?;
let request_str = String::from_utf8_lossy(&buf[..n]);
// Extract authorization code from query params
let code = extract_query_param(&request_str, "code").ok_or(Error::Runtime {
message: "No authorization code in callback".to_string(),
})?;
let returned_state = extract_query_param(&request_str, "state");
if returned_state.as_deref() != Some(&state) {
return Err(Error::Runtime {
message: "OAuth state mismatch — possible CSRF attack".to_string(),
});
}
// Send success response to browser
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>";
let _ = stream.write_all(response.as_bytes()).await;
// Exchange code for tokens
let token_endpoint = self.get_token_endpoint().await?;
let mut params = vec![
("grant_type", "authorization_code"),
("client_id", self.config.client_id.as_str()),
("code", &code),
("redirect_uri", redirect),
("code_verifier", &code_verifier),
];
if let Some(ref secret) = self.config.client_secret {
params.push(("client_secret", secret));
}
self.post_token_request(&token_endpoint, &params).await
}
// -- Device Code Flow --
async fn acquire_device_code(&self) -> Result<TokenResponse> {
let discovery = self.get_discovery().await?;
let device_endpoint = discovery
.device_authorization_endpoint
.ok_or(Error::Runtime {
message: "OIDC discovery did not provide device_authorization_endpoint".to_string(),
})?;
let params = [
("client_id", self.config.client_id.as_str()),
("scope", &self.scopes_string()),
];
let resp = self
.http_client
.post(&device_endpoint)
.form(&params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Device code request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Device code request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let device_resp: DeviceCodeResponse = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse device code response: {e}"),
})?;
// Display instructions to user
info!(
"To sign in, visit {} and enter code: {}",
device_resp.verification_uri, device_resp.user_code
);
if let Some(ref uri) = device_resp.verification_uri_complete {
info!("Or visit: {uri}");
}
// Poll token endpoint
let token_endpoint = self.get_token_endpoint().await?;
let poll_interval = Duration::from_secs(device_resp.interval.unwrap_or(5));
let deadline = Instant::now() + Duration::from_secs(device_resp.expires_in);
loop {
if Instant::now() >= deadline {
return Err(Error::Runtime {
message: "Device code flow timed out waiting for user authentication"
.to_string(),
});
}
tokio::time::sleep(poll_interval).await;
let poll_params = [
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("client_id", self.config.client_id.as_str()),
("device_code", &device_resp.device_code),
];
let poll_resp = self
.http_client
.post(&token_endpoint)
.form(&poll_params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Device code poll failed: {e}"),
})?;
if poll_resp.status().is_success() {
return poll_resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
});
}
// Check for pending / slow_down errors
let body = poll_resp.text().await.unwrap_or_default();
if body.contains("authorization_pending") {
continue;
}
if body.contains("slow_down") {
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
return Err(Error::Runtime {
message: format!("Device code poll failed: {body}"),
});
}
}
// -- Azure Managed Identity Flow --
async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result<TokenResponse> {
let resource = self.scopes_string().replace("/.default", "");
let mut url = format!(
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
urlencoding::encode(&resource),
);
if let Some(cid) = mi_client_id {
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
}
let resp = self
.http_client
.get(&url)
.header("Metadata", "true")
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Azure IMDS request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Azure IMDS returned status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse IMDS token response: {e}"),
})
}
// -- Workload Identity Federation Flow --
async fn acquire_workload_identity(&self, token_file: &str) -> Result<TokenResponse> {
let federated_token =
tokio::fs::read_to_string(token_file)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to read federated token file '{token_file}': {e}"),
})?;
let token_endpoint = self.get_token_endpoint().await?;
let params = [
("grant_type", "client_credentials"),
("client_id", self.config.client_id.as_str()),
(
"client_assertion_type",
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
),
("client_assertion", federated_token.trim()),
("scope", &self.scopes_string()),
];
self.post_token_request(&token_endpoint, &params).await
}
// -- Refresh Token Flow --
async fn refresh_with_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let token_endpoint = self.get_token_endpoint().await?;
let mut params = vec![
("grant_type", "refresh_token"),
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
];
if let Some(ref secret) = self.config.client_secret {
params.push(("client_secret", secret.as_str()));
}
self.post_token_request(&token_endpoint, &params).await
}
// -- Shared Helpers --
async fn post_token_request(
&self,
endpoint: &str,
params: &[(&str, &str)],
) -> Result<TokenResponse> {
let resp = self
.http_client
.post(endpoint)
.form(params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Token request to {endpoint} failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Token request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
})
}
}
#[async_trait::async_trait]
impl HeaderProvider for OAuthHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
let token = self.get_valid_token().await?;
Ok(HashMap::from([(
"authorization".to_string(),
format!("Bearer {token}"),
)]))
}
}
// -- Utility functions --
fn base64_url_encode(input: &[u8]) -> String {
use base64::Engine;
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input)
}
/// Extract a query parameter value from a raw HTTP GET request line.
fn extract_query_param(request: &str, param: &str) -> Option<String> {
let first_line = request.lines().next()?;
let path = first_line.split_whitespace().nth(1)?;
let query = path.split('?').nth(1)?;
for pair in query.split('&') {
let mut kv = pair.splitn(2, '=');
if let (Some(key), Some(value)) = (kv.next(), kv.next())
&& key == param
{
return Some(urlencoding::decode(value).ok()?.into_owned());
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_query_param() {
let request = "GET /callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n";
assert_eq!(
extract_query_param(request, "code"),
Some("abc123".to_string())
);
assert_eq!(
extract_query_param(request, "state"),
Some("xyz".to_string())
);
assert_eq!(extract_query_param(request, "missing"), None);
}
#[test]
fn test_extract_query_param_encoded() {
let request = "GET /callback?code=abc%20123&state=x%26y HTTP/1.1\r\n";
assert_eq!(
extract_query_param(request, "code"),
Some("abc 123".to_string())
);
assert_eq!(
extract_query_param(request, "state"),
Some("x&y".to_string())
);
}
#[test]
fn test_token_state_expiry() {
let mut state = TokenState::new();
assert!(state.is_expired(Duration::from_secs(0)));
state.access_token = Some("tok".to_string());
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
assert!(!state.is_expired(Duration::from_secs(300)));
assert!(state.is_expired(Duration::from_secs(601)));
}
#[test]
fn test_base64_url_encode() {
let input = b"hello world";
let encoded = base64_url_encode(input);
assert!(!encoded.contains('+'));
assert!(!encoded.contains('/'));
assert!(!encoded.contains('='));
}
#[test]
fn test_scopes_string() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope1".to_string(), "scope2".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let provider = OAuthHeaderProvider::new(config).unwrap();
assert_eq!(provider.scopes_string(), "scope1 scope2");
}
#[test]
fn test_token_endpoint_derivation() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/my-tenant/v2.0".to_string(),
client_id: "id".to_string(),
client_secret: None,
scopes: vec!["api://test/.default".to_string()],
flow: OAuthFlow::DeviceCode,
refresh_buffer_secs: None,
};
let provider = OAuthHeaderProvider::new(config).unwrap();
assert_eq!(
provider.get_token_endpoint_from_issuer(),
"https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token"
);
}
#[test]
fn test_client_credentials_requires_secret() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[test]
fn test_empty_scopes_rejected() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![],
flow: OAuthFlow::DeviceCode,
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
}

View File

@@ -27,7 +27,9 @@ use crate::table::UpdateResult;
use crate::table::query::create_multi_vector_plan;
use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics};
use crate::utils::background_cache::BackgroundCache;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::utils::{
resolve_arrow_field_path, supported_btree_data_type, supported_vector_data_type,
};
use crate::{DistanceType, Error};
use crate::{
error::Result,
@@ -1563,11 +1565,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
Index::FTS(p) => ("FTS", Some(to_json(p)?)),
Index::Auto => {
let schema = self.schema().await?;
let field = schema
.field_with_name(&column)
.map_err(|_| Error::InvalidInput {
message: format!("Column {} not found in schema", column),
})?;
let field = resolve_arrow_field_path(&schema, &column)?;
if supported_vector_data_type(field.data_type()) {
body[METRIC_TYPE_KEY] =
serde_json::Value::String(DistanceType::L2.to_string().to_lowercase());
@@ -3505,7 +3503,7 @@ mod tests {
{
"index_name": "my_idx",
"index_uuid": "34255f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["my_column"],
"columns": ["metadata.`my.column`"],
"index_status": "done",
},
]
@@ -3544,7 +3542,7 @@ mod tests {
IndexConfig {
name: "my_idx".into(),
index_type: IndexType::LabelList,
columns: vec!["my_column".into()],
columns: vec!["metadata.`my.column`".into()],
},
];
assert_eq!(indices, expected);

View File

@@ -282,17 +282,15 @@ pub use self::merge::MergeResult;
/// date) and [`LsmWriteSpec::with_writer_config_defaults`] (default
/// `ShardWriter` configuration recorded in the MemWAL index).
///
/// All variants require the table to have an unenforced primary key.
///
/// Install a spec with [`Table::set_lsm_write_spec`] and remove it with
/// [`Table::unset_lsm_write_spec`]. The actual `merge_insert` dispatch
/// onto the MemWAL writer is a follow-up.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LsmWriteSpec {
/// Hash-bucket sharding by the unenforced primary key column.
/// Hash-bucket sharding by a scalar column.
///
/// `column` must equal the table's currently-set single-column
/// unenforced primary key. `num_buckets` must be in `[1, 1024]`.
/// `column` must be a non-nested column with a supported scalar type.
/// `num_buckets` must be in `[1, 1024]`.
/// Iceberg-compatible Murmur3-x86-32 (seed 0) is used so each row's
/// `bucket(column, num_buckets)` value is stable across processes.
Bucket {
@@ -1298,21 +1296,15 @@ impl Table {
///
/// [`LsmWriteSpec`] chooses one of three sharding strategies:
///
/// - [`LsmWriteSpec::bucket`] — hash-bucket writes by the single-column
/// unenforced primary key.
/// - [`LsmWriteSpec::bucket`] — hash-bucket writes by a scalar column.
/// - [`LsmWriteSpec::identity`] — shard by the raw value of a scalar column.
/// - [`LsmWriteSpec::unsharded`] — route every write to a single shard.
///
/// All variants require the table to have an unenforced primary key
/// ([`Table::set_unenforced_primary_key`]); bucket sharding additionally
/// requires it to be the single column being bucketed.
///
/// # Example
///
/// ```
/// # use lancedb::table::{LsmWriteSpec, Table};
/// # async fn example(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
/// table.set_unenforced_primary_key(["id"]).await?;
/// table
/// .set_lsm_write_spec(
/// LsmWriteSpec::bucket("id", 16).with_maintained_indexes(["id_idx"]),
@@ -2171,6 +2163,33 @@ impl NativeTable {
}
}
fn resolve_index_field(
schema: &lance_core::datatypes::Schema,
column: &str,
) -> Result<(String, Field)> {
lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput {
message: format!("Invalid field path `{}`: {}", column, e),
})?;
let field_path = schema
.resolve_case_insensitive(column)
.ok_or_else(|| Error::Schema {
message: format!(
"Field path `{}` not found in schema. Available field paths: {}",
column,
schema.field_paths().join(", ")
),
})?;
let field = field_path.last().expect("field path should be non-empty");
let path_segments = field_path
.iter()
.map(|field| field.name.as_str())
.collect::<Vec<_>>();
let canonical_path = lance_core::datatypes::format_field_path(&path_segments);
Ok((canonical_path, Field::from(*field)))
}
// Convert LanceDB Index to Lance IndexParams
async fn make_index_params(
&self,
@@ -2661,15 +2680,13 @@ impl BaseTable for NativeTable {
message: "Multi-column (composite) indices are not yet supported".to_string(),
});
}
let schema = self.schema().await?;
let field = schema.field_with_name(&opts.columns[0])?;
let lance_idx_params = self.make_index_params(field, opts.index.clone()).await?;
let index_type = self.get_index_type_for_field(field, &opts.index);
let columns = [field.name().as_str()];
self.dataset.ensure_mutable()?;
let mut dataset = (*self.dataset.get().await?).clone();
let (column, field) = Self::resolve_index_field(dataset.schema(), &opts.columns[0])?;
let lance_idx_params = self.make_index_params(&field, opts.index.clone()).await?;
let index_type = self.get_index_type_for_field(&field, &opts.index);
let columns = [column.as_str()];
let mut builder = dataset
.create_index_builder(&columns, index_type, lance_idx_params.as_ref())
.train(opts.train)
@@ -2787,54 +2804,88 @@ impl BaseTable for NativeTable {
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
let dataset = self.dataset.get().await?;
let indices = dataset.load_indices().await?;
let results = futures::stream::iter(indices.as_slice()).then(|idx| async {
// skip Lance internal indexes
if idx.name == FRAG_REUSE_INDEX_NAME {
return None;
}
let stats = match dataset.index_statistics(idx.name.as_str()).await {
Ok(stats) => stats,
Err(e) => {
log::warn!("Failed to get statistics for index {} ({}): {}", idx.name, idx.uuid, e);
let results = futures::stream::iter(indices.as_slice())
.then(|idx| async {
// skip Lance internal indexes
if idx.name == FRAG_REUSE_INDEX_NAME {
return None;
}
};
let stats: serde_json::Value = match serde_json::from_str(&stats) {
Ok(stats) => stats,
Err(e) => {
log::warn!("Failed to deserialize index statistics for index {} ({}): {}", idx.name, idx.uuid, e);
return None;
}
};
let stats = match dataset.index_statistics(idx.name.as_str()).await {
Ok(stats) => stats,
Err(e) => {
log::warn!(
"Failed to get statistics for index {} ({}): {}",
idx.name,
idx.uuid,
e
);
return None;
}
};
let Some(index_type) = stats.get("index_type").and_then(|v| v.as_str()) else {
log::warn!("Index statistics was missing 'index_type' field for index {} ({})", idx.name, idx.uuid);
return None;
};
let stats: serde_json::Value = match serde_json::from_str(&stats) {
Ok(stats) => stats,
Err(e) => {
log::warn!(
"Failed to deserialize index statistics for index {} ({}): {}",
idx.name,
idx.uuid,
e
);
return None;
}
};
let index_type: crate::index::IndexType = match index_type.parse() {
Ok(index_type) => index_type,
Err(e) => {
log::warn!("Failed to parse index type for index {} ({}): {}", idx.name, idx.uuid, e);
return None;
}
};
let mut columns = Vec::with_capacity(idx.fields.len());
for field_id in &idx.fields {
let Some(field) = dataset.schema().field_by_id(*field_id) else {
log::warn!("The index {} ({}) referenced a field with id {} which does not exist in the schema", idx.name, idx.uuid, field_id);
let Some(index_type) = stats.get("index_type").and_then(|v| v.as_str()) else {
log::warn!(
"Index statistics was missing 'index_type' field for index {} ({})",
idx.name,
idx.uuid
);
return None;
};
columns.push(field.name.clone());
}
let name = idx.name.clone();
Some(IndexConfig { index_type, columns, name })
}).collect::<Vec<_>>().await;
let index_type: crate::index::IndexType = match index_type.parse() {
Ok(index_type) => index_type,
Err(e) => {
log::warn!(
"Failed to parse index type for index {} ({}): {}",
idx.name,
idx.uuid,
e
);
return None;
}
};
let mut columns = Vec::with_capacity(idx.fields.len());
for field_id in &idx.fields {
let field_path = match dataset.schema().field_path(*field_id) {
Ok(field_path) => field_path,
Err(e) => {
log::warn!(
"Failed to resolve field path for index {} ({}) field id {}: {}",
idx.name,
idx.uuid,
field_id,
e
);
return None;
}
};
columns.push(field_path);
}
let name = idx.name.clone();
Some(IndexConfig {
index_type,
columns,
name,
})
})
.collect::<Vec<_>>()
.await;
Ok(results.into_iter().flatten().collect())
}
@@ -3037,13 +3088,14 @@ pub struct FragmentSummaryStats {
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use arrow_array::{
Array, BooleanArray, FixedSizeListArray, Int32Array, LargeStringArray, RecordBatch,
RecordBatchIterator, RecordBatchReader, StringArray,
Array, ArrayRef, BooleanArray, FixedSizeListArray, Int32Array, LargeStringArray,
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, StructArray,
builder::{ListBuilder, StringBuilder},
};
use arrow_array::{BinaryArray, LargeBinaryArray};
@@ -3063,6 +3115,7 @@ mod tests {
use crate::query::Select;
use crate::query::{ExecutableQuery, QueryBase};
use crate::test_utils::connection::new_test_connection;
use lance_index::scalar::FullTextSearchQuery;
#[tokio::test]
async fn test_open() {
let tmp_dir = tempdir().unwrap();
@@ -3650,6 +3703,222 @@ mod tests {
assert_eq!(stats.num_unindexed_rows, 0);
}
#[tokio::test]
async fn test_create_index_nested_field_paths() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = ConnectBuilder::new(uri).execute().await.unwrap();
let num_rows = 512;
let dimension = 8;
let metadata = Arc::new(StructArray::from(vec![(
Arc::new(Field::new("user_id", DataType::Int32, false)),
Arc::new(Int32Array::from_iter_values(0..num_rows)) as ArrayRef,
)]));
let vector_values = arrow_array::Float32Array::from_iter_values(
(0..num_rows * dimension).map(|v| v as f32),
);
let embeddings =
Arc::new(create_fixed_size_list(vector_values, dimension).unwrap()) as ArrayRef;
let image = Arc::new(StructArray::from(vec![(
Arc::new(Field::new(
"embedding",
embeddings.data_type().clone(),
false,
)),
embeddings,
)]));
let payload = Arc::new(StructArray::from(vec![(
Arc::new(Field::new("text", DataType::Utf8, false)),
Arc::new(StringArray::from_iter_values(
(0..num_rows).map(|i| format!("document {}", i)),
)) as ArrayRef,
)]));
let meta_data = Arc::new(StructArray::from(vec![(
Arc::new(Field::new("user-id", DataType::Int32, false)),
Arc::new(Int32Array::from_iter_values(0..num_rows)) as ArrayRef,
)]));
let literal = Arc::new(StructArray::from(vec![(
Arc::new(Field::new("a.b", DataType::Int32, false)),
Arc::new(Int32Array::from_iter_values(0..num_rows)) as ArrayRef,
)]));
let schema = Arc::new(Schema::new(vec![
Field::new("metadata", metadata.data_type().clone(), false),
Field::new("image", image.data_type().clone(), false),
Field::new("payload", payload.data_type().clone(), false),
Field::new("meta-data", meta_data.data_type().clone(), false),
Field::new("literal", literal.data_type().clone(), false),
]));
let batch =
RecordBatch::try_new(schema, vec![metadata, image, payload, meta_data, literal])
.unwrap();
let table = conn
.create_table("nested_index_paths", batch)
.execute()
.await
.unwrap();
table
.create_index(
&["metadata.user_id"],
Index::BTree(BTreeIndexBuilder::default()),
)
.name("metadata_user_id_idx".to_string())
.execute()
.await
.unwrap();
table
.create_index(&["image.embedding"], Index::Auto)
.name("image_embedding_idx".to_string())
.execute()
.await
.unwrap();
table
.create_index(&["payload.text"], Index::FTS(Default::default()))
.name("payload_text_idx".to_string())
.execute()
.await
.unwrap();
table
.create_index(
&["`meta-data`.`user-id`"],
Index::BTree(BTreeIndexBuilder::default()),
)
.name("escaped_names_idx".to_string())
.execute()
.await
.unwrap();
table
.create_index(
&["literal.`a.b`"],
Index::BTree(BTreeIndexBuilder::default()),
)
.name("literal_dot_idx".to_string())
.execute()
.await
.unwrap();
let mut index_configs = table.list_indices().await.unwrap();
index_configs.sort_by(|left, right| left.name.cmp(&right.name));
let indexed_columns = index_configs
.iter()
.map(|index| {
(
index.name.as_str(),
index.columns.as_slice(),
index.index_type.clone(),
)
})
.collect::<Vec<_>>();
assert_eq!(
indexed_columns,
vec![
(
"escaped_names_idx",
&["`meta-data`.`user-id`".to_string()][..],
crate::index::IndexType::BTree,
),
(
"image_embedding_idx",
&["image.embedding".to_string()][..],
crate::index::IndexType::IvfPq,
),
(
"literal_dot_idx",
&["literal.`a.b`".to_string()][..],
crate::index::IndexType::BTree,
),
(
"metadata_user_id_idx",
&["metadata.user_id".to_string()][..],
crate::index::IndexType::BTree,
),
(
"payload_text_idx",
&["payload.text".to_string()][..],
crate::index::IndexType::FTS,
),
]
);
let vector_results = table
.query()
.nearest_to(&[0.0; 8])
.unwrap()
.column("image.embedding")
.limit(1)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(
vector_results
.iter()
.map(|batch| batch.num_rows())
.sum::<usize>(),
1
);
let default_vector_results = table
.query()
.nearest_to(&[0.0; 8])
.unwrap()
.limit(1)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(
default_vector_results
.iter()
.map(|batch| batch.num_rows())
.sum::<usize>(),
1
);
let fts_results = table
.query()
.full_text_search(FullTextSearchQuery::new("document".to_string()))
.limit(5)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(!fts_results.is_empty());
let filtered_results = table
.query()
.only_if("metadata.user_id = 42")
.limit(1)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(
filtered_results
.iter()
.map(|batch| batch.num_rows())
.sum::<usize>(),
1
);
}
#[tokio::test]
async fn test_create_bitmap_index() {
let tmp_dir = tempdir().unwrap();
@@ -4323,21 +4592,6 @@ mod tests {
.unwrap();
let table = conn.create_table("t", reader).execute().await.unwrap();
// Reject when no PK is set.
let err = table
.set_lsm_write_spec(LsmWriteSpec::bucket("id", 4))
.await
.expect_err("should reject without PK");
assert!(matches!(err, Error::Lance { .. }), "got {:?}", err);
// Set PK, then a mismatched column on the spec must be rejected.
table.set_unenforced_primary_key(["id"]).await.unwrap();
let err = table
.set_lsm_write_spec(LsmWriteSpec::bucket("name", 4))
.await
.expect_err("should reject column != PK");
assert!(matches!(err, Error::Lance { .. }), "got {:?}", err);
// Reject num_buckets out of range.
for bad in [0u32, 1025] {
let err = table
@@ -4403,9 +4657,6 @@ mod tests {
.unwrap();
let table = conn.create_table("t", reader).execute().await.unwrap();
// Lance's MemWAL still requires *some* unenforced primary key on
// the dataset; Unsharded just skips the per-row hashing step.
table.set_unenforced_primary_key(["id"]).await.unwrap();
table
.set_lsm_write_spec(LsmWriteSpec::unsharded())
.await
@@ -4452,7 +4703,6 @@ mod tests {
.unwrap();
let table = conn.create_table("t", reader).execute().await.unwrap();
table.set_unenforced_primary_key(["id"]).await.unwrap();
table
.set_lsm_write_spec(
LsmWriteSpec::identity("region")
@@ -4508,7 +4758,6 @@ mod tests {
table.unset_lsm_write_spec().await.unwrap_err();
// Install a spec, then unset it.
table.set_unenforced_primary_key(["id"]).await.unwrap();
table
.set_lsm_write_spec(LsmWriteSpec::bucket("id", 4))
.await

View File

@@ -268,7 +268,9 @@ mod tests {
};
use crate::query::{ExecutableQuery, QueryBase, Select};
use crate::table::add_data::NaNVectorBehavior;
use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions};
use crate::table::{
ColumnDefinition, ColumnKind, NewColumnTransform, Table, TableDefinition, WriteOptions,
};
use crate::test_utils::TestCustomError;
use crate::test_utils::embeddings::MockEmbed;
@@ -518,6 +520,225 @@ mod tests {
}
}
/// Regression test for https://github.com/lancedb/lancedb/issues/3136.
///
/// When a column is added via `add_columns` AFTER an embedding column,
/// the table schema becomes `[..., embedding, extra]`. Subsequent
/// `table.add()` calls used to fail with a CastError because columns
/// were matched positionally rather than by name.
#[tokio::test]
async fn test_add_with_embeddings_after_add_columns() {
let registry = Arc::new(MemoryRegistry::new());
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
registry.register("mock", mock_embedding).unwrap();
let conn = connect("memory://")
.embedding_registry(registry)
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new(
"text_vec",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
),
]));
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
let table_def = TableDefinition::new(
schema.clone(),
vec![
ColumnDefinition {
kind: ColumnKind::Physical,
},
ColumnDefinition {
kind: ColumnKind::Embedding(embedding_def),
},
],
);
let rich_schema = table_def.into_rich_schema();
let table = conn
.create_empty_table("embed_evol_test", rich_schema)
.execute()
.await
.unwrap();
// Seed a row so add_columns has data to compute against.
let seed_batch = record_batch!(("text", Utf8, ["hello"])).unwrap();
table.add(seed_batch).execute().await.unwrap();
// Add a new physical column AFTER the embedding column.
table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("score".into(), "42.0".into())]),
None,
)
.await
.unwrap();
// Now add data including the new column but WITHOUT the embedding.
// The input batch column order is [text, score]; after computing the
// embedding it becomes [text, score, text_vec], but the table schema
// is [text, text_vec, score]. Columns must be matched by name.
let new_schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new("score", DataType::Float64, true),
]));
let new_batch = RecordBatch::try_new(
new_schema,
vec![
Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"])),
Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])),
],
)
.unwrap();
table.add(new_batch).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let results: Vec<RecordBatch> = table
.query()
.select(Select::columns(&["text", "text_vec", "score"]))
.execute()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 3);
for batch in &results {
// text_vec must be populated for the newly added rows too.
assert_eq!(batch.column(1).null_count(), 0);
}
}
/// Like `test_add_with_embeddings_after_add_columns`, but the column
/// added after the embedding is a nested struct rather than a scalar.
/// Verifies that name-based column matching also works when the
/// post-embedding column has a complex Arrow type.
#[tokio::test]
async fn test_add_with_embeddings_after_add_nested_columns() {
let registry = Arc::new(MemoryRegistry::new());
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
registry.register("mock", mock_embedding).unwrap();
let conn = connect("memory://")
.embedding_registry(registry)
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new(
"text_vec",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
),
]));
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
let table_def = TableDefinition::new(
schema,
vec![
ColumnDefinition {
kind: ColumnKind::Physical,
},
ColumnDefinition {
kind: ColumnKind::Embedding(embedding_def),
},
],
);
let rich_schema = table_def.into_rich_schema();
let table = conn
.create_empty_table("embed_nested_test", rich_schema)
.execute()
.await
.unwrap();
let seed_batch = record_batch!(("text", Utf8, ["hello"])).unwrap();
table.add(seed_batch).execute().await.unwrap();
// Add a STRUCT column after the embedding column.
let meta_struct = DataType::Struct(
vec![
Field::new("source", DataType::Utf8, true),
Field::new("score", DataType::Float64, true),
]
.into(),
);
let nested_schema = Arc::new(Schema::new(vec![Field::new(
"meta",
meta_struct.clone(),
true,
)]));
table
.add_columns(NewColumnTransform::AllNulls(nested_schema), None)
.await
.unwrap();
// Insert with the nested struct present but the embedding column
// absent. The computed batch is [text, meta, text_vec], but the
// table schema is [text, text_vec, meta] — only name-based matching
// can put `meta` (a struct) in the right slot.
let source = Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"]));
let score = Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0]));
let meta = Arc::new(arrow_array::StructArray::from(vec![
(
Arc::new(Field::new("source", DataType::Utf8, true)),
source as Arc<dyn arrow_array::Array>,
),
(
Arc::new(Field::new("score", DataType::Float64, true)),
score as Arc<dyn arrow_array::Array>,
),
]));
let new_schema = Arc::new(Schema::new(vec![
Field::new("text", DataType::Utf8, false),
Field::new("meta", meta_struct, true),
]));
let new_batch = RecordBatch::try_new(
new_schema,
vec![
Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"])),
meta,
],
)
.unwrap();
table.add(new_batch).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let results: Vec<RecordBatch> = table
.query()
.select(Select::columns(&["text", "text_vec", "meta"]))
.execute()
.await
.unwrap()
.try_collect()
.await
.unwrap();
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 3);
for batch in &results {
assert_eq!(batch.schema().field(2).name(), "meta");
assert!(matches!(
batch.schema().field(2).data_type(),
DataType::Struct(_)
));
// text_vec must be populated for the newly added rows too.
assert_eq!(batch.column(1).null_count(), 0);
}
}
#[tokio::test]
async fn test_add_casts_to_table_schema() {
let table_schema = Arc::new(Schema::new(vec![

View File

@@ -6,7 +6,7 @@ pub(crate) mod background_cache;
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema, SchemaRef};
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::RecordBatchStream;
use futures::{FutureExt, Stream};
@@ -152,14 +152,10 @@ pub fn validate_namespace(namespace: &[String]) -> Result<()> {
/// Find one default column to create index or perform vector query.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find a vector column.
let candidates = schema
.fields()
.iter()
.filter_map(|field| match infer_vector_dim(field.data_type()) {
Ok(d) if dim.is_none() || dim == Some(d as i32) => Some(field.name()),
_ => None,
})
.collect::<Vec<_>>();
let mut candidates = Vec::new();
for field in schema.fields() {
collect_vector_columns(field, &mut Vec::new(), dim, &mut candidates);
}
if candidates.is_empty() {
Err(Error::InvalidInput {
message: format!(
@@ -180,6 +176,63 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
}
}
fn collect_vector_columns(
field: &Field,
path: &mut Vec<String>,
dim: Option<i32>,
candidates: &mut Vec<String>,
) {
path.push(field.name().clone());
match infer_vector_dim(field.data_type()) {
Ok(d) if dim.is_none() || dim == Some(d as i32) => {
let path_segments = path.iter().map(String::as_str).collect::<Vec<_>>();
candidates.push(lance_core::datatypes::format_field_path(&path_segments));
}
_ => {
if let DataType::Struct(fields) = field.data_type() {
for child in fields {
collect_vector_columns(child, path, dim, candidates);
}
}
}
}
path.pop();
}
pub(crate) fn resolve_arrow_field_path(schema: &Schema, column: &str) -> Result<Field> {
let segments =
lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput {
message: format!("Invalid field path `{}`: {}", column, e),
})?;
let mut fields = schema.fields();
for (idx, segment) in segments.iter().enumerate() {
let field = find_field(fields, segment).ok_or_else(|| Error::Schema {
message: format!("Field path `{}` not found in schema", column),
})?;
if idx + 1 == segments.len() {
return Ok(field.clone());
}
fields = match field.data_type() {
DataType::Struct(fields) => fields,
_ => {
return Err(Error::Schema {
message: format!("Field path `{}` not found in schema", column),
});
}
};
}
unreachable!("parse_field_path returns at least one segment")
}
fn find_field<'a>(fields: &'a Fields, name: &str) -> Option<&'a Field> {
fields
.iter()
.find(|field| field.name() == name)
.map(|field| field.as_ref())
}
pub fn supported_btree_data_type(dtype: &DataType) -> bool {
dtype.is_integer()
|| dtype.is_floating()
@@ -450,6 +503,49 @@ mod tests {
"vec"
);
let schema_with_nested_vec_col = Schema::new(vec![
Field::new("id", DataType::Int16, true),
Field::new(
"image",
DataType::Struct(
vec![Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, false)),
10,
),
false,
)]
.into(),
),
false,
),
]);
assert_eq!(
default_vector_column(&schema_with_nested_vec_col, None).unwrap(),
"image.embedding"
);
let schema_with_escaped_nested_vec_col = Schema::new(vec![Field::new(
"image-meta",
DataType::Struct(
vec![Field::new(
"embedding.v1",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, false)),
10,
),
false,
)]
.into(),
),
false,
)]);
assert_eq!(
default_vector_column(&schema_with_escaped_nested_vec_col, None).unwrap(),
"`image-meta`.`embedding.v1`"
);
let multi_vec_col = Schema::new(vec![
Field::new("id", DataType::Int16, true),
Field::new(
@@ -469,6 +565,48 @@ mod tests {
.to_string()
.contains("More than one")
);
let multi_nested_vec_col = Schema::new(vec![
Field::new(
"image",
DataType::Struct(
vec![Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, false)),
10,
),
false,
)]
.into(),
),
false,
),
Field::new(
"text",
DataType::Struct(
vec![Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, false)),
50,
),
false,
)]
.into(),
),
false,
),
]);
assert_eq!(
default_vector_column(&multi_nested_vec_col, Some(50)).unwrap(),
"text.embedding"
);
let err = default_vector_column(&multi_nested_vec_col, None)
.unwrap_err()
.to_string();
assert!(err.contains("image.embedding"));
assert!(err.contains("text.embedding"));
}
#[test]