mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-26 10:30:40 +00:00
Compare commits
46 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6de8f42dcd | ||
|
|
5c3bd68e58 | ||
|
|
4be85444f0 | ||
|
|
68c07f333f | ||
|
|
814a379e08 | ||
|
|
f31561c5bb | ||
|
|
e0c5ceac03 | ||
|
|
e93bb3355a | ||
|
|
b75991eb07 | ||
|
|
97ca9bb943 | ||
|
|
fa1b04f341 | ||
|
|
367abe99d2 | ||
|
|
52ce2c995c | ||
|
|
e71a00998c | ||
|
|
39a2ac0a1c | ||
|
|
bc7b344fa4 | ||
|
|
f91d2f5fec | ||
|
|
cf81b6419f | ||
|
|
0498ac1f2f | ||
|
|
aeb1c3ee6a | ||
|
|
f9ae46c0e7 | ||
|
|
84bf022fb1 | ||
|
|
310967eceb | ||
|
|
154dbeee2a | ||
|
|
c9c08ac8b9 | ||
|
|
e253f5d9b6 | ||
|
|
05b4fb0990 | ||
|
|
613b9c1099 | ||
|
|
d5948576b9 | ||
|
|
0d3fc7860a | ||
|
|
531cec075c | ||
|
|
0e486511fa | ||
|
|
367262662d | ||
|
|
11efaf46ae | ||
|
|
1ea22ee5ef | ||
|
|
8cef8806e9 | ||
|
|
a3cd7fce69 | ||
|
|
48ddc833dd | ||
|
|
2802764092 | ||
|
|
37bbb0dba1 | ||
|
|
155ec16161 | ||
|
|
636b8b5bbd | ||
|
|
715b81c86b | ||
|
|
7e1616376e | ||
|
|
d5ac5b949a | ||
|
|
7be6f45e0b |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.26.2"
|
||||
current_version = "0.27.0-beta.4"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -29,6 +29,7 @@ runs:
|
||||
if: ${{ inputs.arm-build == 'false' }}
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
maturin-version: "1.12.4"
|
||||
command: build
|
||||
working-directory: python
|
||||
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
|
||||
@@ -44,6 +45,7 @@ runs:
|
||||
if: ${{ inputs.arm-build == 'true' }}
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
maturin-version: "1.12.4"
|
||||
command: build
|
||||
working-directory: python
|
||||
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
|
||||
|
||||
1
.github/workflows/build_mac_wheel/action.yml
vendored
1
.github/workflows/build_mac_wheel/action.yml
vendored
@@ -20,6 +20,7 @@ runs:
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
command: build
|
||||
maturin-version: "1.12.4"
|
||||
# TODO: pass through interpreter
|
||||
args: ${{ inputs.args }}
|
||||
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
|
||||
|
||||
@@ -25,6 +25,7 @@ runs:
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
command: build
|
||||
maturin-version: "1.12.4"
|
||||
args: ${{ inputs.args }}
|
||||
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
|
||||
working-directory: python
|
||||
|
||||
15
.github/workflows/nodejs.yml
vendored
15
.github/workflows/nodejs.yml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
paths:
|
||||
- Cargo.toml
|
||||
- nodejs/**
|
||||
- rust/**
|
||||
- docs/src/js/**
|
||||
- .github/workflows/nodejs.yml
|
||||
- docker-compose.yml
|
||||
@@ -77,8 +78,11 @@ jobs:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: actions/setup-node@v3
|
||||
name: Setup Node.js 20 for build
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
# @napi-rs/cli v3 requires Node >= 20.12 (via @inquirer/prompts@8).
|
||||
# Build always on Node 20; tests run on the matrix version below.
|
||||
node-version: 20
|
||||
cache: 'npm'
|
||||
cache-dependency-path: nodejs/package-lock.json
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
@@ -86,12 +90,16 @@ jobs:
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
npm install -g @napi-rs/cli
|
||||
- name: Build
|
||||
run: |
|
||||
npm ci --include=optional
|
||||
npm run build:debug -- --profile ci
|
||||
npm run tsc
|
||||
- uses: actions/setup-node@v3
|
||||
name: Setup Node.js ${{ matrix.node-version }} for test
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
- name: Compile TypeScript
|
||||
run: npm run tsc
|
||||
- name: Setup localstack
|
||||
working-directory: .
|
||||
run: docker compose up --detach --wait
|
||||
@@ -144,7 +152,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
brew install protobuf
|
||||
npm install -g @napi-rs/cli
|
||||
- name: Build
|
||||
run: |
|
||||
npm ci --include=optional
|
||||
|
||||
44
.github/workflows/npm-publish.yml
vendored
44
.github/workflows/npm-publish.yml
vendored
@@ -128,16 +128,13 @@ jobs:
|
||||
- target: x86_64-unknown-linux-musl
|
||||
# This one seems to need some extra memory
|
||||
host: ubuntu-2404-8x-x64
|
||||
# https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile
|
||||
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
|
||||
features: fp16kernels
|
||||
pre_build: |-
|
||||
set -e &&
|
||||
apk add protobuf-dev curl &&
|
||||
ln -s /usr/lib/gcc/x86_64-alpine-linux-musl/14.2.0/crtbeginS.o /usr/lib/crtbeginS.o &&
|
||||
ln -s /usr/lib/libgcc_s.so /usr/lib/libgcc.so &&
|
||||
CC=gcc &&
|
||||
CXX=g++
|
||||
sudo apt-get update &&
|
||||
sudo apt-get install -y protobuf-compiler pkg-config &&
|
||||
rustup target add x86_64-unknown-linux-musl &&
|
||||
export EXTRA_ARGS="-x"
|
||||
- target: aarch64-unknown-linux-gnu
|
||||
host: ubuntu-2404-8x-x64
|
||||
# https://github.com/napi-rs/napi-rs/blob/main/debian-aarch64.Dockerfile
|
||||
@@ -153,15 +150,13 @@ jobs:
|
||||
rustup target add aarch64-unknown-linux-gnu
|
||||
- target: aarch64-unknown-linux-musl
|
||||
host: ubuntu-2404-8x-x64
|
||||
# https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile
|
||||
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
|
||||
features: ","
|
||||
pre_build: |-
|
||||
set -e &&
|
||||
apk add protobuf-dev &&
|
||||
sudo apt-get update &&
|
||||
sudo apt-get install -y protobuf-compiler &&
|
||||
rustup target add aarch64-unknown-linux-musl &&
|
||||
export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc &&
|
||||
export CXX_aarch64_unknown_linux_musl=aarch64-linux-musl-g++
|
||||
export EXTRA_ARGS="-x"
|
||||
name: build - ${{ matrix.settings.target }}
|
||||
runs-on: ${{ matrix.settings.host }}
|
||||
defaults:
|
||||
@@ -192,12 +187,18 @@ jobs:
|
||||
.cargo-cache
|
||||
target/
|
||||
key: nodejs-${{ matrix.settings.target }}-cargo-${{ matrix.settings.host }}
|
||||
- name: Setup toolchain
|
||||
run: ${{ matrix.settings.setup }}
|
||||
if: ${{ matrix.settings.setup }}
|
||||
shell: bash
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
- name: Install Zig
|
||||
uses: mlugg/setup-zig@v2
|
||||
if: ${{ contains(matrix.settings.target, 'musl') }}
|
||||
with:
|
||||
version: 0.14.1
|
||||
- name: Install cargo-zigbuild
|
||||
uses: taiki-e/install-action@v2
|
||||
if: ${{ contains(matrix.settings.target, 'musl') }}
|
||||
with:
|
||||
tool: cargo-zigbuild
|
||||
- name: Build in docker
|
||||
uses: addnab/docker-run-action@v3
|
||||
if: ${{ matrix.settings.docker }}
|
||||
@@ -210,24 +211,24 @@ jobs:
|
||||
run: |
|
||||
set -e
|
||||
${{ matrix.settings.pre_build }}
|
||||
npx napi build --platform --release --no-const-enum \
|
||||
npx napi build --platform --release \
|
||||
--features ${{ matrix.settings.features }} \
|
||||
--target ${{ matrix.settings.target }} \
|
||||
--dts ../lancedb/native.d.ts \
|
||||
--js ../lancedb/native.js \
|
||||
--strip \
|
||||
dist/
|
||||
--output-dir dist/
|
||||
- name: Build
|
||||
run: |
|
||||
${{ matrix.settings.pre_build }}
|
||||
npx napi build --platform --release --no-const-enum \
|
||||
npx napi build --platform --release \
|
||||
--features ${{ matrix.settings.features }} \
|
||||
--target ${{ matrix.settings.target }} \
|
||||
--dts ../lancedb/native.d.ts \
|
||||
--js ../lancedb/native.js \
|
||||
--strip \
|
||||
$EXTRA_ARGS \
|
||||
dist/
|
||||
--output-dir dist/
|
||||
if: ${{ !matrix.settings.docker }}
|
||||
shell: bash
|
||||
- name: Upload artifact
|
||||
@@ -355,7 +356,8 @@ jobs:
|
||||
if [[ $DRY_RUN == "true" ]]; then
|
||||
ARGS="$ARGS --dry-run"
|
||||
fi
|
||||
if [[ $GITHUB_REF =~ refs/tags/v(.*)-beta.* ]]; then
|
||||
VERSION=$(node -p "require('./package.json').version")
|
||||
if [[ $VERSION == *-* ]]; then
|
||||
ARGS="$ARGS --tag preview"
|
||||
fi
|
||||
npm publish $ARGS
|
||||
|
||||
5
.github/workflows/python.yml
vendored
5
.github/workflows/python.yml
vendored
@@ -8,7 +8,12 @@ on:
|
||||
paths:
|
||||
- Cargo.toml
|
||||
- python/**
|
||||
- rust/**
|
||||
- .github/workflows/python.yml
|
||||
- .github/workflows/build_linux_wheel/**
|
||||
- .github/workflows/build_mac_wheel/**
|
||||
- .github/workflows/build_windows_wheel/**
|
||||
- .github/workflows/run_tests/**
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
|
||||
6
.github/workflows/rust.yml
vendored
6
.github/workflows/rust.yml
vendored
@@ -100,7 +100,9 @@ jobs:
|
||||
lfs: true
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Install dependencies
|
||||
run: sudo apt install -y protobuf-compiler libssl-dev
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- uses: rui314/setup-mold@v1
|
||||
- name: Make Swap
|
||||
run: |
|
||||
@@ -183,7 +185,7 @@ jobs:
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
msrv: ["1.88.0"] # This should match up with rust-version in Cargo.toml
|
||||
msrv: ["1.91.0"] # This should match up with rust-version in Cargo.toml
|
||||
env:
|
||||
# Need up-to-date compilers for kernels
|
||||
CC: clang-18
|
||||
|
||||
929
Cargo.lock
generated
929
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
48
Cargo.toml
48
Cargo.toml
@@ -5,30 +5,30 @@ exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
authors = ["LanceDB Devs <dev@lancedb.com>"]
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
rust-version = "1.88.0"
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=2.0.1", default-features = false }
|
||||
lance-core = "=2.0.1"
|
||||
lance-datagen = "=2.0.1"
|
||||
lance-file = "=2.0.1"
|
||||
lance-io = { "version" = "=2.0.1", default-features = false }
|
||||
lance-index = "=2.0.1"
|
||||
lance-linalg = "=2.0.1"
|
||||
lance-namespace = "=2.0.1"
|
||||
lance-namespace-impls = { "version" = "=2.0.1", default-features = false }
|
||||
lance-table = "=2.0.1"
|
||||
lance-testing = "=2.0.1"
|
||||
lance-datafusion = "=2.0.1"
|
||||
lance-encoding = "=2.0.1"
|
||||
lance-arrow = "=2.0.1"
|
||||
lance = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "57.2", optional = false }
|
||||
@@ -40,13 +40,15 @@ arrow-schema = "57.2"
|
||||
arrow-select = "57.2"
|
||||
arrow-cast = "57.2"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "51.0", default-features = false }
|
||||
datafusion-catalog = "51.0"
|
||||
datafusion-common = { version = "51.0", default-features = false }
|
||||
datafusion-execution = "51.0"
|
||||
datafusion-expr = "51.0"
|
||||
datafusion-physical-plan = "51.0"
|
||||
datafusion-physical-expr = "51.0"
|
||||
datafusion = { version = "52.1", default-features = false }
|
||||
datafusion-catalog = "52.1"
|
||||
datafusion-common = { version = "52.1", default-features = false }
|
||||
datafusion-execution = "52.1"
|
||||
datafusion-expr = "52.1"
|
||||
datafusion-functions = "52.1"
|
||||
datafusion-physical-plan = "52.1"
|
||||
datafusion-physical-expr = "52.1"
|
||||
datafusion-sql = "52.1"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "2.7.1", default-features = false, features = [
|
||||
"num-traits",
|
||||
|
||||
@@ -52,14 +52,21 @@ plugins:
|
||||
options:
|
||||
docstring_style: numpy
|
||||
heading_level: 3
|
||||
show_source: true
|
||||
show_symbol_type_in_heading: true
|
||||
show_signature_annotations: true
|
||||
show_root_heading: true
|
||||
show_docstring_examples: true
|
||||
show_docstring_attributes: false
|
||||
show_docstring_other_parameters: true
|
||||
show_symbol_type_heading: true
|
||||
show_labels: false
|
||||
show_if_no_docstring: true
|
||||
show_source: false
|
||||
members_order: source
|
||||
docstring_section_style: list
|
||||
signature_crossrefs: true
|
||||
separate_signature: true
|
||||
filters:
|
||||
- "!^_"
|
||||
import:
|
||||
# for cross references
|
||||
- https://arrow.apache.org/docs/objects.inv
|
||||
@@ -113,7 +120,7 @@ markdown_extensions:
|
||||
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
||||
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
||||
- markdown.extensions.toc:
|
||||
toc_depth: 3
|
||||
toc_depth: 4
|
||||
permalink: true
|
||||
permalink_title: Anchor link to this section
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.26.2</version>
|
||||
<version>0.27.0-beta.4</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -8,6 +8,14 @@
|
||||
|
||||
## Properties
|
||||
|
||||
### numDeletedRows
|
||||
|
||||
```ts
|
||||
numDeletedRows: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### version
|
||||
|
||||
```ts
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# LanceDB Java SDK
|
||||
# LanceDB Java Enterprise Client
|
||||
|
||||
## Configuration and Initialization
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.26.2-final.0</version>
|
||||
<version>0.27.0-beta.4</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.26.2-final.0</version>
|
||||
<version>0.27.0-beta.4</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>2.0.1</lance-core.version>
|
||||
<lance-core.version>3.1.0-beta.2</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.26.2"
|
||||
version = "0.27.0-beta.4"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -19,11 +19,11 @@ arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
napi = { version = "2.16.8", default-features = false, features = [
|
||||
napi = { version = "3.8.3", default-features = false, features = [
|
||||
"napi9",
|
||||
"async"
|
||||
] }
|
||||
napi-derive = "2.16.4"
|
||||
napi-derive = "3.5.2"
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
log.workspace = true
|
||||
@@ -33,7 +33,7 @@ aws-lc-sys = "=0.28.0"
|
||||
aws-lc-rs = "=1.13.0"
|
||||
|
||||
[build-dependencies]
|
||||
napi-build = "2.1"
|
||||
napi-build = "2.3.1"
|
||||
|
||||
[features]
|
||||
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
|
||||
|
||||
@@ -63,6 +63,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
tableFromIPC,
|
||||
DataType,
|
||||
Dictionary,
|
||||
Uint8: ArrowUint8,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
} = <any>arrow;
|
||||
type Schema = ApacheArrow["Schema"];
|
||||
@@ -362,6 +363,38 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
).toEqual(new Float64().toString());
|
||||
});
|
||||
|
||||
it("will infer FixedSizeList<Float32> from Float32Array values", async function () {
|
||||
const table = makeArrowTable([
|
||||
{ id: "a", vector: new Float32Array([0.1, 0.2, 0.3]) },
|
||||
{ id: "b", vector: new Float32Array([0.4, 0.5, 0.6]) },
|
||||
]);
|
||||
|
||||
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(
|
||||
true,
|
||||
);
|
||||
const vectorType = table.getChild("vector")?.type;
|
||||
expect(vectorType.listSize).toBe(3);
|
||||
expect(vectorType.children[0].type.toString()).toEqual(
|
||||
new Float32().toString(),
|
||||
);
|
||||
});
|
||||
|
||||
it("will infer FixedSizeList<Uint8> from Uint8Array values", async function () {
|
||||
const table = makeArrowTable([
|
||||
{ id: "a", vector: new Uint8Array([1, 2, 3]) },
|
||||
{ id: "b", vector: new Uint8Array([4, 5, 6]) },
|
||||
]);
|
||||
|
||||
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(
|
||||
true,
|
||||
);
|
||||
const vectorType = table.getChild("vector")?.type;
|
||||
expect(vectorType.listSize).toBe(3);
|
||||
expect(vectorType.children[0].type.toString()).toEqual(
|
||||
new ArrowUint8().toString(),
|
||||
);
|
||||
});
|
||||
|
||||
it("will use dictionary encoded strings if asked", async function () {
|
||||
const table = makeArrowTable([{ str: "hello" }]);
|
||||
expect(DataType.isUtf8(table.getChild("str")?.type)).toBe(true);
|
||||
|
||||
@@ -1697,6 +1697,65 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
expect(results2[0].text).toBe(data[1].text);
|
||||
});
|
||||
|
||||
test("full text search fast search", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const data = [{ text: "hello world", vector: [0.1, 0.2, 0.3], id: 1 }];
|
||||
const table = await db.createTable("test", data);
|
||||
await table.createIndex("text", {
|
||||
config: Index.fts(),
|
||||
});
|
||||
|
||||
// Insert unindexed data after creating the index.
|
||||
await table.add([{ text: "xyz", vector: [0.4, 0.5, 0.6], id: 2 }]);
|
||||
|
||||
const withFlatSearch = await table
|
||||
.search("xyz", "fts")
|
||||
.limit(10)
|
||||
.toArray();
|
||||
expect(withFlatSearch.length).toBeGreaterThan(0);
|
||||
|
||||
const fastSearchResults = await table
|
||||
.search("xyz", "fts")
|
||||
.fastSearch()
|
||||
.limit(10)
|
||||
.toArray();
|
||||
expect(fastSearchResults.length).toBe(0);
|
||||
|
||||
const nearestToTextFastSearch = await table
|
||||
.query()
|
||||
.nearestToText("xyz")
|
||||
.fastSearch()
|
||||
.limit(10)
|
||||
.toArray();
|
||||
expect(nearestToTextFastSearch.length).toBe(0);
|
||||
|
||||
// fastSearch should be chainable with other methods.
|
||||
const chainedFastSearch = await table
|
||||
.search("xyz", "fts")
|
||||
.fastSearch()
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.toArray();
|
||||
expect(chainedFastSearch.length).toBe(0);
|
||||
|
||||
await table.optimize();
|
||||
|
||||
const indexedFastSearch = await table
|
||||
.search("xyz", "fts")
|
||||
.fastSearch()
|
||||
.limit(10)
|
||||
.toArray();
|
||||
expect(indexedFastSearch.length).toBeGreaterThan(0);
|
||||
|
||||
const indexedNearestToTextFastSearch = await table
|
||||
.query()
|
||||
.nearestToText("xyz")
|
||||
.fastSearch()
|
||||
.limit(10)
|
||||
.toArray();
|
||||
expect(indexedNearestToTextFastSearch.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("prewarm full text search index", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const data = [
|
||||
@@ -2145,3 +2204,36 @@ describe("when creating an empty table", () => {
|
||||
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
// Ensure we can create float32 arrays without using Arrow
|
||||
// by utilizing native JS TypedArray support
|
||||
//
|
||||
// https://github.com/lancedb/lancedb/issues/3115
|
||||
describe("when creating a table with Float32Array vectors", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
beforeEach(() => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
});
|
||||
afterEach(() => {
|
||||
tmpDir.removeCallback();
|
||||
});
|
||||
|
||||
it("should persist Float32Array as FixedSizeList<Float32> in the LanceDB schema", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const table = await db.createTable("test", [
|
||||
{ id: "a", vector: new Float32Array([0.1, 0.2, 0.3]) },
|
||||
{ id: "b", vector: new Float32Array([0.4, 0.5, 0.6]) },
|
||||
]);
|
||||
|
||||
const schema = await table.schema();
|
||||
const vectorField = schema.fields.find((f) => f.name === "vector");
|
||||
expect(vectorField).toBeDefined();
|
||||
expect(vectorField!.type).toBeInstanceOf(FixedSizeList);
|
||||
|
||||
const fsl = vectorField!.type as FixedSizeList;
|
||||
expect(fsl.listSize).toBe(3);
|
||||
expect(fsl.children[0].type.typeId).toBe(Type.Float);
|
||||
// precision: HALF=0, SINGLE=1, DOUBLE=2
|
||||
expect((fsl.children[0].type as Float32).precision).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -20,6 +20,8 @@ import {
|
||||
Float32,
|
||||
Float64,
|
||||
Int,
|
||||
Int8,
|
||||
Int16,
|
||||
Int32,
|
||||
Int64,
|
||||
LargeBinary,
|
||||
@@ -35,6 +37,8 @@ import {
|
||||
Timestamp,
|
||||
Type,
|
||||
Uint8,
|
||||
Uint16,
|
||||
Uint32,
|
||||
Utf8,
|
||||
Vector,
|
||||
makeVector as arrowMakeVector,
|
||||
@@ -529,7 +533,8 @@ function isObject(value: unknown): value is Record<string, unknown> {
|
||||
!(value instanceof Date) &&
|
||||
!(value instanceof Set) &&
|
||||
!(value instanceof Map) &&
|
||||
!(value instanceof Buffer)
|
||||
!(value instanceof Buffer) &&
|
||||
!ArrayBuffer.isView(value)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -588,6 +593,13 @@ function inferType(
|
||||
return new Bool();
|
||||
} else if (value instanceof Buffer) {
|
||||
return new Binary();
|
||||
} else if (ArrayBuffer.isView(value) && !(value instanceof DataView)) {
|
||||
const info = typedArrayToArrowType(value);
|
||||
if (info !== undefined) {
|
||||
const child = new Field("item", info.elementType, true);
|
||||
return new FixedSizeList(info.length, child);
|
||||
}
|
||||
return undefined;
|
||||
} else if (Array.isArray(value)) {
|
||||
if (value.length === 0) {
|
||||
return undefined; // Without any values we can't infer the type
|
||||
@@ -746,6 +758,32 @@ function makeListVector(lists: unknown[][]): Vector<unknown> {
|
||||
return listBuilder.finish().toVector();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map a JS TypedArray instance to the corresponding Arrow element DataType
|
||||
* and its length. Returns undefined if the value is not a recognized TypedArray.
|
||||
*/
|
||||
function typedArrayToArrowType(
|
||||
value: ArrayBufferView,
|
||||
): { elementType: DataType; length: number } | undefined {
|
||||
if (value instanceof Float32Array)
|
||||
return { elementType: new Float32(), length: value.length };
|
||||
if (value instanceof Float64Array)
|
||||
return { elementType: new Float64(), length: value.length };
|
||||
if (value instanceof Uint8Array)
|
||||
return { elementType: new Uint8(), length: value.length };
|
||||
if (value instanceof Uint16Array)
|
||||
return { elementType: new Uint16(), length: value.length };
|
||||
if (value instanceof Uint32Array)
|
||||
return { elementType: new Uint32(), length: value.length };
|
||||
if (value instanceof Int8Array)
|
||||
return { elementType: new Int8(), length: value.length };
|
||||
if (value instanceof Int16Array)
|
||||
return { elementType: new Int16(), length: value.length };
|
||||
if (value instanceof Int32Array)
|
||||
return { elementType: new Int32(), length: value.length };
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Helper function to convert an Array of JS values to an Arrow Vector */
|
||||
function makeVector(
|
||||
values: unknown[],
|
||||
@@ -814,6 +852,16 @@ function makeVector(
|
||||
"makeVector cannot infer the type if all values are null or undefined",
|
||||
);
|
||||
}
|
||||
if (ArrayBuffer.isView(sampleValue) && !(sampleValue instanceof DataView)) {
|
||||
const info = typedArrayToArrowType(sampleValue);
|
||||
if (info !== undefined) {
|
||||
const fslType = new FixedSizeList(
|
||||
info.length,
|
||||
new Field("item", info.elementType, true),
|
||||
);
|
||||
return vectorFromArray(values, fslType);
|
||||
}
|
||||
}
|
||||
if (Array.isArray(sampleValue)) {
|
||||
// Default Arrow inference doesn't handle list types
|
||||
return makeListVector(values as unknown[][]);
|
||||
|
||||
@@ -273,7 +273,9 @@ export async function connect(
|
||||
let nativeProvider: NativeJsHeaderProvider | undefined;
|
||||
if (finalHeaderProvider) {
|
||||
if (typeof finalHeaderProvider === "function") {
|
||||
nativeProvider = new NativeJsHeaderProvider(finalHeaderProvider);
|
||||
nativeProvider = new NativeJsHeaderProvider(async () =>
|
||||
finalHeaderProvider(),
|
||||
);
|
||||
} else if (
|
||||
finalHeaderProvider &&
|
||||
typeof finalHeaderProvider.getHeaders === "function"
|
||||
|
||||
@@ -684,19 +684,17 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
||||
|
||||
rerank(reranker: Reranker): VectorQuery {
|
||||
super.doCall((inner) =>
|
||||
inner.rerank({
|
||||
rerankHybrid: async (_, args) => {
|
||||
const vecResults = await fromBufferToRecordBatch(args.vecResults);
|
||||
const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
|
||||
const result = await reranker.rerankHybrid(
|
||||
args.query,
|
||||
vecResults as RecordBatch,
|
||||
ftsResults as RecordBatch,
|
||||
);
|
||||
inner.rerank(async (args) => {
|
||||
const vecResults = await fromBufferToRecordBatch(args.vecResults);
|
||||
const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
|
||||
const result = await reranker.rerankHybrid(
|
||||
args.query,
|
||||
vecResults as RecordBatch,
|
||||
ftsResults as RecordBatch,
|
||||
);
|
||||
|
||||
const buffer = fromRecordBatchToBuffer(result);
|
||||
return buffer;
|
||||
},
|
||||
const buffer = fromRecordBatchToBuffer(result);
|
||||
return buffer;
|
||||
}),
|
||||
);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
1781
nodejs/package-lock.json
generated
1781
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.26.2",
|
||||
"version": "0.27.0-beta.4",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
@@ -21,19 +21,16 @@
|
||||
},
|
||||
"types": "dist/index.d.ts",
|
||||
"napi": {
|
||||
"name": "lancedb",
|
||||
"triples": {
|
||||
"defaults": false,
|
||||
"additional": [
|
||||
"aarch64-apple-darwin",
|
||||
"x86_64-unknown-linux-gnu",
|
||||
"aarch64-unknown-linux-gnu",
|
||||
"x86_64-unknown-linux-musl",
|
||||
"aarch64-unknown-linux-musl",
|
||||
"x86_64-pc-windows-msvc",
|
||||
"aarch64-pc-windows-msvc"
|
||||
]
|
||||
}
|
||||
"binaryName": "lancedb",
|
||||
"targets": [
|
||||
"aarch64-apple-darwin",
|
||||
"x86_64-unknown-linux-gnu",
|
||||
"aarch64-unknown-linux-gnu",
|
||||
"x86_64-unknown-linux-musl",
|
||||
"aarch64-unknown-linux-musl",
|
||||
"x86_64-pc-windows-msvc",
|
||||
"aarch64-pc-windows-msvc"
|
||||
]
|
||||
},
|
||||
"license": "Apache-2.0",
|
||||
"repository": {
|
||||
@@ -46,7 +43,7 @@
|
||||
"@aws-sdk/client-s3": "^3.33.0",
|
||||
"@biomejs/biome": "^1.7.3",
|
||||
"@jest/globals": "^29.7.0",
|
||||
"@napi-rs/cli": "^2.18.3",
|
||||
"@napi-rs/cli": "^3.5.1",
|
||||
"@types/axios": "^0.14.0",
|
||||
"@types/jest": "^29.1.2",
|
||||
"@types/node": "^22.7.4",
|
||||
@@ -75,9 +72,9 @@
|
||||
"os": ["darwin", "linux", "win32"],
|
||||
"scripts": {
|
||||
"artifacts": "napi artifacts",
|
||||
"build:debug": "napi build --platform --no-const-enum --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
|
||||
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js --output-dir lancedb",
|
||||
"postbuild:debug": "shx mkdir -p dist && shx cp lancedb/*.node dist/",
|
||||
"build:release": "napi build --platform --no-const-enum --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||
"build:release": "napi build --platform --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js --output-dir dist",
|
||||
"postbuild:release": "shx mkdir -p dist && shx cp lancedb/*.node dist/",
|
||||
"build": "npm run build:debug && npm run tsc",
|
||||
"build-release": "npm run build:release && npm run tsc",
|
||||
@@ -91,7 +88,7 @@
|
||||
"prepublishOnly": "napi prepublish -t npm",
|
||||
"test": "jest --verbose",
|
||||
"integration": "S3_TEST=1 npm run test",
|
||||
"universal": "napi universal",
|
||||
"universal": "napi universalize",
|
||||
"version": "napi version"
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
@@ -8,10 +8,10 @@ use lancedb::database::{CreateTableMode, Database};
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::*;
|
||||
|
||||
use crate::ConnectionOptions;
|
||||
use crate::error::NapiErrorExt;
|
||||
use crate::header::JsHeaderProvider;
|
||||
use crate::table::Table;
|
||||
use crate::ConnectionOptions;
|
||||
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
|
||||
|
||||
use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use napi::{
|
||||
bindgen_prelude::*,
|
||||
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
|
||||
};
|
||||
use napi::{bindgen_prelude::*, threadsafe_function::ThreadsafeFunction};
|
||||
use napi_derive::napi;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
type GetHeadersFn = ThreadsafeFunction<(), Promise<HashMap<String, String>>, (), Status, false>;
|
||||
|
||||
/// JavaScript HeaderProvider implementation that wraps a JavaScript callback.
|
||||
/// This is the only native header provider - all header provider implementations
|
||||
/// should provide a JavaScript function that returns headers.
|
||||
#[napi]
|
||||
pub struct JsHeaderProvider {
|
||||
get_headers_fn: Arc<ThreadsafeFunction<(), ErrorStrategy::CalleeHandled>>,
|
||||
get_headers_fn: Arc<GetHeadersFn>,
|
||||
}
|
||||
|
||||
impl Clone for JsHeaderProvider {
|
||||
@@ -29,9 +28,12 @@ impl Clone for JsHeaderProvider {
|
||||
impl JsHeaderProvider {
|
||||
/// Create a new JsHeaderProvider from a JavaScript callback
|
||||
#[napi(constructor)]
|
||||
pub fn new(get_headers_callback: JsFunction) -> Result<Self> {
|
||||
pub fn new(
|
||||
get_headers_callback: Function<(), Promise<HashMap<String, String>>>,
|
||||
) -> Result<Self> {
|
||||
let get_headers_fn = get_headers_callback
|
||||
.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))
|
||||
.build_threadsafe_function()
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
@@ -51,7 +53,7 @@ impl lancedb::remote::HeaderProvider for JsHeaderProvider {
|
||||
async fn get_headers(&self) -> lancedb::error::Result<HashMap<String, String>> {
|
||||
// Call the JavaScript function asynchronously
|
||||
let promise: Promise<HashMap<String, String>> =
|
||||
self.get_headers_fn.call_async(Ok(())).await.map_err(|e| {
|
||||
self.get_headers_fn.call_async(()).await.map_err(|e| {
|
||||
lancedb::error::Error::Runtime {
|
||||
message: format!("Failed to call JavaScript get_headers: {}", e),
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::index::Index as LanceDbIndex;
|
||||
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
|
||||
use lancedb::index::vector::{
|
||||
IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
|
||||
IvfRqIndexBuilder,
|
||||
};
|
||||
use lancedb::index::Index as LanceDbIndex;
|
||||
use napi_derive::napi;
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
@@ -60,7 +60,7 @@ pub struct OpenTableOptions {
|
||||
pub storage_options: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[napi::module_init]
|
||||
#[napi_derive::module_init]
|
||||
fn init() {
|
||||
let env = Env::new()
|
||||
.filter_or("LANCEDB_LOG", "warn")
|
||||
|
||||
@@ -17,11 +17,11 @@ use lancedb::query::VectorQuery as LanceDbVectorQuery;
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::napi;
|
||||
|
||||
use crate::error::convert_error;
|
||||
use crate::error::NapiErrorExt;
|
||||
use crate::error::convert_error;
|
||||
use crate::iterator::RecordBatchIterator;
|
||||
use crate::rerankers::RerankHybridCallbackArgs;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::rerankers::RerankerCallbacks;
|
||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||
|
||||
#[napi]
|
||||
@@ -42,7 +42,7 @@ impl Query {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> {
|
||||
pub fn full_text_search(&mut self, query: Object) -> napi::Result<()> {
|
||||
let query = parse_fts_query(query)?;
|
||||
self.inner = self.inner.clone().full_text_search(query);
|
||||
Ok(())
|
||||
@@ -235,7 +235,7 @@ impl VectorQuery {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> {
|
||||
pub fn full_text_search(&mut self, query: Object) -> napi::Result<()> {
|
||||
let query = parse_fts_query(query)?;
|
||||
self.inner = self.inner.clone().full_text_search(query);
|
||||
Ok(())
|
||||
@@ -272,11 +272,13 @@ impl VectorQuery {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.rerank(Arc::new(Reranker::new(callbacks)));
|
||||
pub fn rerank(
|
||||
&mut self,
|
||||
rerank_hybrid: Function<RerankHybridCallbackArgs, Promise<Buffer>>,
|
||||
) -> napi::Result<()> {
|
||||
let reranker = Reranker::new(rerank_hybrid)?;
|
||||
self.inner = self.inner.clone().rerank(Arc::new(reranker));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
@@ -523,12 +525,12 @@ impl JsFullTextQuery {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_fts_query(query: napi::JsObject) -> napi::Result<FullTextSearchQuery> {
|
||||
if let Ok(Some(query)) = query.get::<_, &JsFullTextQuery>("query") {
|
||||
fn parse_fts_query(query: Object) -> napi::Result<FullTextSearchQuery> {
|
||||
if let Ok(Some(query)) = query.get::<&JsFullTextQuery>("query") {
|
||||
Ok(FullTextSearchQuery::new_query(query.inner.clone()))
|
||||
} else if let Ok(Some(query_text)) = query.get::<_, String>("query") {
|
||||
} else if let Ok(Some(query_text)) = query.get::<String>("query") {
|
||||
let mut query_text = query_text;
|
||||
let columns = query.get::<_, Option<Vec<String>>>("columns")?.flatten();
|
||||
let columns = query.get::<Option<Vec<String>>>("columns")?.flatten();
|
||||
|
||||
let is_phrase =
|
||||
query_text.len() >= 2 && query_text.starts_with('"') && query_text.ends_with('"');
|
||||
@@ -549,15 +551,12 @@ fn parse_fts_query(query: napi::JsObject) -> napi::Result<FullTextSearchQuery> {
|
||||
}
|
||||
};
|
||||
let mut query = FullTextSearchQuery::new_query(query);
|
||||
if let Some(cols) = columns {
|
||||
if !cols.is_empty() {
|
||||
query = query.with_columns(&cols).map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
"Failed to set full text search columns: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
if let Some(cols) = columns
|
||||
&& !cols.is_empty()
|
||||
{
|
||||
query = query.with_columns(&cols).map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to set full text search columns: {}", e))
|
||||
})?;
|
||||
}
|
||||
Ok(query)
|
||||
} else {
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use async_trait::async_trait;
|
||||
use napi::{
|
||||
bindgen_prelude::*,
|
||||
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
|
||||
};
|
||||
use napi::{bindgen_prelude::*, threadsafe_function::ThreadsafeFunction};
|
||||
use napi_derive::napi;
|
||||
|
||||
use lancedb::ipc::batches_to_ipc_file;
|
||||
@@ -15,27 +12,28 @@ use lancedb::{error::Error, ipc::ipc_file_to_batches};
|
||||
|
||||
use crate::error::NapiErrorExt;
|
||||
|
||||
type RerankHybridFn = ThreadsafeFunction<
|
||||
RerankHybridCallbackArgs,
|
||||
Promise<Buffer>,
|
||||
RerankHybridCallbackArgs,
|
||||
Status,
|
||||
false,
|
||||
>;
|
||||
|
||||
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
|
||||
/// This contains references to the callbacks that can be used to invoke the
|
||||
/// reranking methods on the NodeJS implementation and handles serializing the
|
||||
/// record batches to Arrow IPC buffers.
|
||||
#[napi]
|
||||
pub struct Reranker {
|
||||
/// callback to the Javascript which will call the rerankHybrid method of
|
||||
/// some Reranker implementation
|
||||
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
|
||||
rerank_hybrid: RerankHybridFn,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl Reranker {
|
||||
#[napi]
|
||||
pub fn new(callbacks: RerankerCallbacks) -> Self {
|
||||
let rerank_hybrid = callbacks
|
||||
.rerank_hybrid
|
||||
.create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
|
||||
.unwrap();
|
||||
|
||||
Self { rerank_hybrid }
|
||||
pub fn new(
|
||||
rerank_hybrid: Function<RerankHybridCallbackArgs, Promise<Buffer>>,
|
||||
) -> napi::Result<Self> {
|
||||
let rerank_hybrid = rerank_hybrid.build_threadsafe_function().build()?;
|
||||
Ok(Self { rerank_hybrid })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,16 +47,16 @@ impl lancedb::rerankers::Reranker for Reranker {
|
||||
) -> lancedb::error::Result<RecordBatch> {
|
||||
let callback_args = RerankHybridCallbackArgs {
|
||||
query: query.to_string(),
|
||||
vec_results: batches_to_ipc_file(&[vector_results])?,
|
||||
fts_results: batches_to_ipc_file(&[fts_results])?,
|
||||
vec_results: Buffer::from(batches_to_ipc_file(&[vector_results])?.as_ref()),
|
||||
fts_results: Buffer::from(batches_to_ipc_file(&[fts_results])?.as_ref()),
|
||||
};
|
||||
let promised_buffer: Promise<Buffer> = self
|
||||
.rerank_hybrid
|
||||
.call_async(Ok(callback_args))
|
||||
.call_async(callback_args)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
@@ -77,16 +75,11 @@ impl std::fmt::Debug for Reranker {
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankerCallbacks {
|
||||
pub rerank_hybrid: JsFunction,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankHybridCallbackArgs {
|
||||
pub query: String,
|
||||
pub vec_results: Vec<u8>,
|
||||
pub fts_results: Vec<u8>,
|
||||
pub vec_results: Buffer,
|
||||
pub fts_results: Buffer,
|
||||
}
|
||||
|
||||
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
|
||||
|
||||
@@ -95,8 +95,7 @@ impl napi::bindgen_prelude::FromNapiValue for Session {
|
||||
napi_val: napi::sys::napi_value,
|
||||
) -> napi::Result<Self> {
|
||||
let object: napi::bindgen_prelude::ClassInstance<Self> =
|
||||
napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)?;
|
||||
let copy = object.clone();
|
||||
Ok(copy)
|
||||
unsafe { napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)? };
|
||||
Ok((*object).clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,17 @@ impl Table {
|
||||
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<AddResult> {
|
||||
let batches = ipc_file_to_batches(buf.to_vec())
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||
let batches = batches
|
||||
.into_iter()
|
||||
.map(|batch| {
|
||||
batch.map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
"Failed to read record batch from IPC file: {}",
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let mut op = self.inner_ref()?.add(batches);
|
||||
|
||||
op = if mode == "append" {
|
||||
@@ -742,12 +753,14 @@ impl From<lancedb::table::AddResult> for AddResult {
|
||||
|
||||
#[napi(object)]
|
||||
pub struct DeleteResult {
|
||||
pub num_deleted_rows: i64,
|
||||
pub version: i64,
|
||||
}
|
||||
|
||||
impl From<lancedb::table::DeleteResult> for DeleteResult {
|
||||
fn from(value: lancedb::table::DeleteResult) -> Self {
|
||||
Self {
|
||||
num_deleted_rows: value.num_deleted_rows as i64,
|
||||
version: value.version as i64,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.0-beta.0"
|
||||
current_version = "0.30.0-beta.5"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.30.0-beta.0"
|
||||
version = "0.30.0-beta.5"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
rust-version = "1.88.0"
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[lib]
|
||||
name = "_lancedb"
|
||||
@@ -16,9 +16,11 @@ crate-type = ["cdylib"]
|
||||
[dependencies]
|
||||
arrow = { version = "57.2", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
bytes = "1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
lance-core.workspace = true
|
||||
lance-namespace.workspace = true
|
||||
lance-namespace-impls.workspace = true
|
||||
lance-io.workspace = true
|
||||
env_logger.workspace = true
|
||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||
@@ -28,6 +30,8 @@ pyo3-async-runtimes = { version = "0.26", features = [
|
||||
] }
|
||||
pin-project = "1.1.5"
|
||||
futures.workspace = true
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
snafu.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# LanceDB
|
||||
# LanceDB Python SDK
|
||||
|
||||
A Python library for [LanceDB](https://github.com/lancedb/lancedb).
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
pylance = [
|
||||
"pylance>=1.0.0b14",
|
||||
"pylance>=4.0.0b7",
|
||||
]
|
||||
tests = [
|
||||
"aiohttp",
|
||||
@@ -59,9 +59,9 @@ tests = [
|
||||
"polars>=0.19, <=1.3.0",
|
||||
"tantivy",
|
||||
"pyarrow-stubs",
|
||||
"pylance>=1.0.0b14",
|
||||
"pylance>=4.0.0b7",
|
||||
"requests",
|
||||
"datafusion",
|
||||
"datafusion>=52,<53",
|
||||
]
|
||||
dev = [
|
||||
"ruff",
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import singledispatch
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from lancedb.pydantic import LanceModel, model_to_dict
|
||||
import pyarrow as pa
|
||||
|
||||
from ._lancedb import RecordBatchStream
|
||||
@@ -80,3 +82,32 @@ def peek_reader(
|
||||
yield from reader
|
||||
|
||||
return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_arrow(data) -> pa.Table:
|
||||
"""Convert a single data object to a pa.Table."""
|
||||
raise NotImplementedError(f"to_arrow not implemented for type {type(data)}")
|
||||
|
||||
|
||||
@to_arrow.register(pa.RecordBatch)
|
||||
def _arrow_from_batch(data: pa.RecordBatch) -> pa.Table:
|
||||
return pa.Table.from_batches([data])
|
||||
|
||||
|
||||
@to_arrow.register(pa.Table)
|
||||
def _arrow_from_table(data: pa.Table) -> pa.Table:
|
||||
return data
|
||||
|
||||
|
||||
@to_arrow.register(list)
|
||||
def _arrow_from_list(data: list) -> pa.Table:
|
||||
if not data:
|
||||
raise ValueError("Cannot create table from empty list without a schema")
|
||||
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
dicts = [model_to_dict(d) for d in data]
|
||||
return pa.Table.from_pylist(dicts, schema=schema)
|
||||
|
||||
return pa.Table.from_pylist(data)
|
||||
|
||||
@@ -8,7 +8,7 @@ from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -1541,6 +1541,8 @@ class AsyncConnection(object):
|
||||
storage_options_provider: Optional["StorageOptionsProvider"] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
) -> AsyncTable:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
@@ -1573,6 +1575,9 @@ class AsyncConnection(object):
|
||||
The explicit location (URI) of the table. If provided, the table will be
|
||||
opened from this location instead of deriving it from the database URI
|
||||
and table name.
|
||||
managed_versioning: bool, optional
|
||||
Whether managed versioning is enabled for this table. If provided,
|
||||
avoids a redundant describe_table call when namespace_client is set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -1587,6 +1592,8 @@ class AsyncConnection(object):
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
location=location,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
return AsyncTable(table)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -15,6 +16,8 @@ from .utils import weak_lru
|
||||
@register("gte-text")
|
||||
class GteEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
Deprecated: GTE embeddings should be used through sentence-transformers.
|
||||
|
||||
An embedding function that uses GTE-LARGE MLX format(for Apple silicon devices only)
|
||||
as well as the standard cpu/gpu version from: https://huggingface.co/thenlper/gte-large.
|
||||
|
||||
@@ -61,6 +64,13 @@ class GteEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
warnings.warn(
|
||||
"GTE embeddings as a standalone embedding function are deprecated. "
|
||||
"Use the 'sentence-transformers' embedding function with a GTE model "
|
||||
"instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
self._ndims = None
|
||||
if kwargs:
|
||||
self.mlx = kwargs.get("mlx", False)
|
||||
|
||||
@@ -110,6 +110,9 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
valid_embeddings = {
|
||||
idx: v.embedding for v, idx in zip(rs.data, valid_indices)
|
||||
}
|
||||
except openai.AuthenticationError:
|
||||
logging.error("Authentication failed: Invalid API key provided")
|
||||
raise
|
||||
except openai.BadRequestError:
|
||||
logging.exception("Bad request: %s", texts)
|
||||
return [None] * len(texts)
|
||||
|
||||
@@ -6,6 +6,7 @@ import io
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
import urllib.parse as urlparse
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
@@ -24,6 +25,7 @@ if TYPE_CHECKING:
|
||||
|
||||
@register("siglip")
|
||||
class SigLipEmbeddings(EmbeddingFunction):
|
||||
# Deprecated: prefer CLIP embeddings via `open-clip`.
|
||||
model_name: str = "google/siglip-base-patch16-224"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
@@ -36,6 +38,12 @@ class SigLipEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
warnings.warn(
|
||||
"SigLip embeddings are deprecated. Use CLIP embeddings via the "
|
||||
"'open-clip' embedding function instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
self._torch = attempt_import_or_raise("torch")
|
||||
|
||||
|
||||
@@ -269,6 +269,11 @@ def retry_with_exponential_backoff(
|
||||
# and say that it is assumed that if this portion errors out, it's due
|
||||
# to rate limit but the user should check the error message to be sure.
|
||||
except Exception as e: # noqa: PERF203
|
||||
# Don't retry on authentication errors (e.g., OpenAI 401)
|
||||
# These are permanent failures that won't be fixed by retrying
|
||||
if _is_non_retryable_error(e):
|
||||
raise
|
||||
|
||||
num_retries += 1
|
||||
|
||||
if num_retries > max_retries:
|
||||
@@ -289,6 +294,29 @@ def retry_with_exponential_backoff(
|
||||
return wrapper
|
||||
|
||||
|
||||
def _is_non_retryable_error(error: Exception) -> bool:
|
||||
"""Check if an error should not be retried.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error should not be retried, False otherwise
|
||||
"""
|
||||
# Check for OpenAI authentication errors
|
||||
error_type = type(error).__name__
|
||||
if error_type == "AuthenticationError":
|
||||
return True
|
||||
|
||||
# Check for other common non-retryable HTTP status codes
|
||||
# 401 Unauthorized, 403 Forbidden
|
||||
if hasattr(error, "status_code"):
|
||||
if error.status_code in (401, 403):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -44,7 +44,7 @@ from lance_namespace import (
|
||||
ListNamespacesRequest,
|
||||
CreateNamespaceRequest,
|
||||
DropNamespaceRequest,
|
||||
CreateEmptyTableRequest,
|
||||
DeclareTableRequest,
|
||||
)
|
||||
from lancedb.table import AsyncTable, LanceTable, Table
|
||||
from lancedb.util import validate_table_name
|
||||
@@ -240,7 +240,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
session : Optional[Session]
|
||||
A session to use for this connection
|
||||
"""
|
||||
self._ns = namespace
|
||||
self._namespace_client = namespace
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
self.storage_options = storage_options or {}
|
||||
self.session = session
|
||||
@@ -269,7 +269,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
request = ListTablesRequest(id=namespace, page_token=page_token, limit=limit)
|
||||
response = self._ns.list_tables(request)
|
||||
response = self._namespace_client.list_tables(request)
|
||||
return response.tables if response.tables else []
|
||||
|
||||
@override
|
||||
@@ -309,7 +309,9 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
# Try to describe the table first to see if it exists
|
||||
try:
|
||||
describe_request = DescribeTableRequest(id=table_id)
|
||||
describe_response = self._ns.describe_table(describe_request)
|
||||
describe_response = self._namespace_client.describe_table(
|
||||
describe_request
|
||||
)
|
||||
location = describe_response.location
|
||||
namespace_storage_options = describe_response.storage_options
|
||||
except Exception:
|
||||
@@ -318,20 +320,20 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
|
||||
if location is None:
|
||||
# Table doesn't exist or mode is "create", reserve a new location
|
||||
create_empty_request = CreateEmptyTableRequest(
|
||||
declare_request = DeclareTableRequest(
|
||||
id=table_id,
|
||||
location=None,
|
||||
properties=self.storage_options if self.storage_options else None,
|
||||
)
|
||||
create_empty_response = self._ns.create_empty_table(create_empty_request)
|
||||
declare_response = self._namespace_client.declare_table(declare_request)
|
||||
|
||||
if not create_empty_response.location:
|
||||
if not declare_response.location:
|
||||
raise ValueError(
|
||||
"Table location is missing from create_empty_table response"
|
||||
"Table location is missing from declare_table response"
|
||||
)
|
||||
|
||||
location = create_empty_response.location
|
||||
namespace_storage_options = create_empty_response.storage_options
|
||||
location = declare_response.location
|
||||
namespace_storage_options = declare_response.storage_options
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
@@ -353,7 +355,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
# Only create if namespace returned storage_options (not None)
|
||||
if storage_options_provider is None and namespace_storage_options is not None:
|
||||
storage_options_provider = LanceNamespaceStorageOptionsProvider(
|
||||
namespace=self._ns,
|
||||
namespace=self._namespace_client,
|
||||
table_id=table_id,
|
||||
)
|
||||
|
||||
@@ -371,6 +373,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
storage_options=merged_storage_options,
|
||||
storage_options_provider=storage_options_provider,
|
||||
location=location,
|
||||
namespace_client=self._namespace_client,
|
||||
)
|
||||
|
||||
return tbl
|
||||
@@ -389,7 +392,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace = []
|
||||
table_id = namespace + [name]
|
||||
request = DescribeTableRequest(id=table_id)
|
||||
response = self._ns.describe_table(request)
|
||||
response = self._namespace_client.describe_table(request)
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
@@ -402,10 +405,14 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
# Only create if namespace returned storage_options (not None)
|
||||
if storage_options_provider is None and response.storage_options is not None:
|
||||
storage_options_provider = LanceNamespaceStorageOptionsProvider(
|
||||
namespace=self._ns,
|
||||
namespace=self._namespace_client,
|
||||
table_id=table_id,
|
||||
)
|
||||
|
||||
# Pass managed_versioning to avoid redundant describe_table call in Rust.
|
||||
# Convert None to False since we already have the answer from describe_table.
|
||||
managed_versioning = response.managed_versioning is True
|
||||
|
||||
return self._lance_table_from_uri(
|
||||
name,
|
||||
response.location,
|
||||
@@ -413,6 +420,8 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
storage_options=merged_storage_options,
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
namespace_client=self._namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -422,7 +431,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace = []
|
||||
table_id = namespace + [name]
|
||||
request = DropTableRequest(id=table_id)
|
||||
self._ns.drop_table(request)
|
||||
self._namespace_client.drop_table(request)
|
||||
|
||||
@override
|
||||
def rename_table(
|
||||
@@ -484,7 +493,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
request = ListNamespacesRequest(
|
||||
id=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
response = self._ns.list_namespaces(request)
|
||||
response = self._namespace_client.list_namespaces(request)
|
||||
return ListNamespacesResponse(
|
||||
namespaces=response.namespaces if response.namespaces else [],
|
||||
page_token=response.page_token,
|
||||
@@ -520,7 +529,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
mode=_normalize_create_namespace_mode(mode),
|
||||
properties=properties,
|
||||
)
|
||||
response = self._ns.create_namespace(request)
|
||||
response = self._namespace_client.create_namespace(request)
|
||||
return CreateNamespaceResponse(
|
||||
properties=response.properties if hasattr(response, "properties") else None
|
||||
)
|
||||
@@ -555,7 +564,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
mode=_normalize_drop_namespace_mode(mode),
|
||||
behavior=_normalize_drop_namespace_behavior(behavior),
|
||||
)
|
||||
response = self._ns.drop_namespace(request)
|
||||
response = self._namespace_client.drop_namespace(request)
|
||||
return DropNamespaceResponse(
|
||||
properties=(
|
||||
response.properties if hasattr(response, "properties") else None
|
||||
@@ -581,7 +590,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
Response containing the namespace properties.
|
||||
"""
|
||||
request = DescribeNamespaceRequest(id=namespace)
|
||||
response = self._ns.describe_namespace(request)
|
||||
response = self._namespace_client.describe_namespace(request)
|
||||
return DescribeNamespaceResponse(
|
||||
properties=response.properties if hasattr(response, "properties") else None
|
||||
)
|
||||
@@ -615,7 +624,7 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
request = ListTablesRequest(id=namespace, page_token=page_token, limit=limit)
|
||||
response = self._ns.list_tables(request)
|
||||
response = self._namespace_client.list_tables(request)
|
||||
return ListTablesResponse(
|
||||
tables=response.tables if response.tables else [],
|
||||
page_token=response.page_token,
|
||||
@@ -630,6 +639,8 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
storage_options_provider: Optional[StorageOptionsProvider] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
) -> LanceTable:
|
||||
# Open a table directly from a URI using the location parameter
|
||||
# Note: storage_options should already be merged by the caller
|
||||
@@ -643,6 +654,8 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
)
|
||||
|
||||
# Open the table using the temporary connection with the location parameter
|
||||
# Pass namespace_client to enable managed versioning support
|
||||
# Pass managed_versioning to avoid redundant describe_table call
|
||||
return LanceTable.open(
|
||||
temp_conn,
|
||||
name,
|
||||
@@ -651,6 +664,8 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
location=table_uri,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
|
||||
|
||||
@@ -685,7 +700,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
session : Optional[Session]
|
||||
A session to use for this connection
|
||||
"""
|
||||
self._ns = namespace
|
||||
self._namespace_client = namespace
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
self.storage_options = storage_options or {}
|
||||
self.session = session
|
||||
@@ -713,7 +728,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
request = ListTablesRequest(id=namespace, page_token=page_token, limit=limit)
|
||||
response = self._ns.list_tables(request)
|
||||
response = self._namespace_client.list_tables(request)
|
||||
return response.tables if response.tables else []
|
||||
|
||||
async def create_table(
|
||||
@@ -750,7 +765,9 @@ class AsyncLanceNamespaceDBConnection:
|
||||
# Try to describe the table first to see if it exists
|
||||
try:
|
||||
describe_request = DescribeTableRequest(id=table_id)
|
||||
describe_response = self._ns.describe_table(describe_request)
|
||||
describe_response = self._namespace_client.describe_table(
|
||||
describe_request
|
||||
)
|
||||
location = describe_response.location
|
||||
namespace_storage_options = describe_response.storage_options
|
||||
except Exception:
|
||||
@@ -759,20 +776,20 @@ class AsyncLanceNamespaceDBConnection:
|
||||
|
||||
if location is None:
|
||||
# Table doesn't exist or mode is "create", reserve a new location
|
||||
create_empty_request = CreateEmptyTableRequest(
|
||||
declare_request = DeclareTableRequest(
|
||||
id=table_id,
|
||||
location=None,
|
||||
properties=self.storage_options if self.storage_options else None,
|
||||
)
|
||||
create_empty_response = self._ns.create_empty_table(create_empty_request)
|
||||
declare_response = self._namespace_client.declare_table(declare_request)
|
||||
|
||||
if not create_empty_response.location:
|
||||
if not declare_response.location:
|
||||
raise ValueError(
|
||||
"Table location is missing from create_empty_table response"
|
||||
"Table location is missing from declare_table response"
|
||||
)
|
||||
|
||||
location = create_empty_response.location
|
||||
namespace_storage_options = create_empty_response.storage_options
|
||||
location = declare_response.location
|
||||
namespace_storage_options = declare_response.storage_options
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
@@ -797,7 +814,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
and namespace_storage_options is not None
|
||||
):
|
||||
provider = LanceNamespaceStorageOptionsProvider(
|
||||
namespace=self._ns,
|
||||
namespace=self._namespace_client,
|
||||
table_id=table_id,
|
||||
)
|
||||
else:
|
||||
@@ -817,6 +834,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
storage_options=merged_storage_options,
|
||||
storage_options_provider=provider,
|
||||
location=location,
|
||||
namespace_client=self._namespace_client,
|
||||
)
|
||||
|
||||
lance_table = await asyncio.to_thread(_create_table)
|
||||
@@ -837,7 +855,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace = []
|
||||
table_id = namespace + [name]
|
||||
request = DescribeTableRequest(id=table_id)
|
||||
response = self._ns.describe_table(request)
|
||||
response = self._namespace_client.describe_table(request)
|
||||
|
||||
# Merge storage options: self.storage_options < user options < namespace options
|
||||
merged_storage_options = dict(self.storage_options)
|
||||
@@ -849,10 +867,14 @@ class AsyncLanceNamespaceDBConnection:
|
||||
# Create a storage options provider if not provided by user
|
||||
if storage_options_provider is None and response.storage_options is not None:
|
||||
storage_options_provider = LanceNamespaceStorageOptionsProvider(
|
||||
namespace=self._ns,
|
||||
namespace=self._namespace_client,
|
||||
table_id=table_id,
|
||||
)
|
||||
|
||||
# Capture managed_versioning from describe response.
|
||||
# Convert None to False since we already have the answer from describe_table.
|
||||
managed_versioning = response.managed_versioning is True
|
||||
|
||||
# Open table in a thread
|
||||
def _open_table():
|
||||
temp_conn = LanceDBConnection(
|
||||
@@ -870,6 +892,8 @@ class AsyncLanceNamespaceDBConnection:
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
location=response.location,
|
||||
namespace_client=self._namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
|
||||
lance_table = await asyncio.to_thread(_open_table)
|
||||
@@ -881,7 +905,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace = []
|
||||
table_id = namespace + [name]
|
||||
request = DropTableRequest(id=table_id)
|
||||
self._ns.drop_table(request)
|
||||
self._namespace_client.drop_table(request)
|
||||
|
||||
async def rename_table(
|
||||
self,
|
||||
@@ -943,7 +967,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
request = ListNamespacesRequest(
|
||||
id=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
response = self._ns.list_namespaces(request)
|
||||
response = self._namespace_client.list_namespaces(request)
|
||||
return ListNamespacesResponse(
|
||||
namespaces=response.namespaces if response.namespaces else [],
|
||||
page_token=response.page_token,
|
||||
@@ -978,7 +1002,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
mode=_normalize_create_namespace_mode(mode),
|
||||
properties=properties,
|
||||
)
|
||||
response = self._ns.create_namespace(request)
|
||||
response = self._namespace_client.create_namespace(request)
|
||||
return CreateNamespaceResponse(
|
||||
properties=response.properties if hasattr(response, "properties") else None
|
||||
)
|
||||
@@ -1012,7 +1036,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
mode=_normalize_drop_namespace_mode(mode),
|
||||
behavior=_normalize_drop_namespace_behavior(behavior),
|
||||
)
|
||||
response = self._ns.drop_namespace(request)
|
||||
response = self._namespace_client.drop_namespace(request)
|
||||
return DropNamespaceResponse(
|
||||
properties=(
|
||||
response.properties if hasattr(response, "properties") else None
|
||||
@@ -1039,7 +1063,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
Response containing the namespace properties.
|
||||
"""
|
||||
request = DescribeNamespaceRequest(id=namespace)
|
||||
response = self._ns.describe_namespace(request)
|
||||
response = self._namespace_client.describe_namespace(request)
|
||||
return DescribeNamespaceResponse(
|
||||
properties=response.properties if hasattr(response, "properties") else None
|
||||
)
|
||||
@@ -1072,7 +1096,7 @@ class AsyncLanceNamespaceDBConnection:
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
request = ListTablesRequest(id=namespace, page_token=page_token, limit=limit)
|
||||
response = self._ns.list_tables(request)
|
||||
response = self._namespace_client.list_tables(request)
|
||||
return ListTablesResponse(
|
||||
tables=response.tables if response.tables else [],
|
||||
page_token=response.page_token,
|
||||
|
||||
@@ -606,6 +606,7 @@ class LanceQueryBuilder(ABC):
|
||||
query,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
fast_search=fast_search,
|
||||
)
|
||||
|
||||
if isinstance(query, list):
|
||||
@@ -1456,12 +1457,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
query: str | FullTextQuery,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
fast_search: bool = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._phrase_query = False
|
||||
self.ordering_field_name = ordering_field_name
|
||||
self._reranker = None
|
||||
self._fast_search = fast_search
|
||||
if isinstance(fts_columns, str):
|
||||
fts_columns = [fts_columns]
|
||||
self._fts_columns = fts_columns
|
||||
@@ -1483,6 +1486,19 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
self._phrase_query = phrase_query
|
||||
return self
|
||||
|
||||
def fast_search(self) -> LanceFtsQueryBuilder:
|
||||
"""
|
||||
Skip a flat search of unindexed data. This will improve
|
||||
search performance but search results will not include unindexed data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceFtsQueryBuilder
|
||||
The LanceFtsQueryBuilder object.
|
||||
"""
|
||||
self._fast_search = True
|
||||
return self
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
return Query(
|
||||
columns=self._columns,
|
||||
@@ -1494,6 +1510,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
query=self._query, columns=self._fts_columns
|
||||
),
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
)
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
@@ -1782,6 +1799,26 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
vector_results = LanceHybridQueryBuilder._rank(vector_results, "_distance")
|
||||
fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score")
|
||||
|
||||
# If both result sets are empty (e.g. after hard filtering),
|
||||
# return early to avoid errors in reranking or score restoration.
|
||||
if vector_results.num_rows == 0 and fts_results.num_rows == 0:
|
||||
# Build a minimal empty table with the _relevance_score column
|
||||
combined_schema = pa.unify_schemas(
|
||||
[vector_results.schema, fts_results.schema],
|
||||
)
|
||||
empty = pa.table(
|
||||
{
|
||||
col: pa.array([], type=combined_schema.field(col).type)
|
||||
for col in combined_schema.names
|
||||
}
|
||||
)
|
||||
empty = empty.append_column(
|
||||
"_relevance_score", pa.array([], type=pa.float32())
|
||||
)
|
||||
if not with_row_ids and "_rowid" in empty.column_names:
|
||||
empty = empty.drop(["_rowid"])
|
||||
return empty
|
||||
|
||||
original_distances = None
|
||||
original_scores = None
|
||||
original_distance_row_ids = None
|
||||
|
||||
@@ -218,8 +218,6 @@ class RemoteTable(Table):
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
Currently, the only parameters that matter are
|
||||
the metric and the vector column name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -250,11 +248,6 @@ class RemoteTable(Table):
|
||||
>>> table.create_index("l2", "vector") # doctest: +SKIP
|
||||
"""
|
||||
|
||||
if num_sub_vectors is not None:
|
||||
logging.warning(
|
||||
"num_sub_vectors is not supported on LanceDB cloud."
|
||||
"This parameter will be tuned automatically."
|
||||
)
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
|
||||
214
python/python/lancedb/scannable.py
Normal file
214
python/python/lancedb/scannable.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import singledispatch
|
||||
import sys
|
||||
from typing import Callable, Iterator, Optional
|
||||
from lancedb.arrow import to_arrow
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset as ds
|
||||
|
||||
from .pydantic import LanceModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scannable:
|
||||
schema: pa.Schema
|
||||
num_rows: Optional[int]
|
||||
# Factory function to create a new reader each time (supports re-scanning)
|
||||
reader: Callable[[], pa.RecordBatchReader]
|
||||
# Whether reader can be called more than once. For example, an iterator can
|
||||
# only be consumed once, while a DataFrame can be converted to a new reader
|
||||
# each time.
|
||||
rescannable: bool = True
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_scannable(data) -> Scannable:
|
||||
# Fallback: try iterable protocol
|
||||
if hasattr(data, "__iter__"):
|
||||
return _from_iterable(iter(data))
|
||||
raise NotImplementedError(f"to_scannable not implemented for type {type(data)}")
|
||||
|
||||
|
||||
@to_scannable.register(pa.RecordBatchReader)
|
||||
def _from_reader(data: pa.RecordBatchReader) -> Scannable:
|
||||
# RecordBatchReader can only be consumed once - not rescannable
|
||||
return Scannable(
|
||||
schema=data.schema, num_rows=None, reader=lambda: data, rescannable=False
|
||||
)
|
||||
|
||||
|
||||
@to_scannable.register(pa.RecordBatch)
|
||||
def _from_batch(data: pa.RecordBatch) -> Scannable:
|
||||
return Scannable(
|
||||
schema=data.schema,
|
||||
num_rows=data.num_rows,
|
||||
reader=lambda: pa.RecordBatchReader.from_batches(data.schema, [data]),
|
||||
)
|
||||
|
||||
|
||||
@to_scannable.register(pa.Table)
|
||||
def _from_table(data: pa.Table) -> Scannable:
|
||||
return Scannable(schema=data.schema, num_rows=data.num_rows, reader=data.to_reader)
|
||||
|
||||
|
||||
@to_scannable.register(ds.Dataset)
|
||||
def _from_dataset(data: ds.Dataset) -> Scannable:
|
||||
return Scannable(
|
||||
schema=data.schema,
|
||||
num_rows=data.count_rows(),
|
||||
reader=lambda: data.scanner().to_reader(),
|
||||
)
|
||||
|
||||
|
||||
@to_scannable.register(ds.Scanner)
|
||||
def _from_scanner(data: ds.Scanner) -> Scannable:
|
||||
# Scanner can only be consumed once - not rescannable
|
||||
return Scannable(
|
||||
schema=data.projected_schema,
|
||||
num_rows=None,
|
||||
reader=data.to_reader,
|
||||
rescannable=False,
|
||||
)
|
||||
|
||||
|
||||
@to_scannable.register(list)
|
||||
def _from_list(data: list) -> Scannable:
|
||||
if not data:
|
||||
raise ValueError("Cannot create table from empty list without a schema")
|
||||
table = to_arrow(data)
|
||||
return Scannable(
|
||||
schema=table.schema, num_rows=table.num_rows, reader=table.to_reader
|
||||
)
|
||||
|
||||
|
||||
@to_scannable.register(dict)
|
||||
def _from_dict(data: dict) -> Scannable:
|
||||
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
|
||||
|
||||
|
||||
@to_scannable.register(LanceModel)
|
||||
def _from_lance_model(data: LanceModel) -> Scannable:
|
||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||
|
||||
|
||||
def _from_iterable(data: Iterator) -> Scannable:
|
||||
first_item = next(data, None)
|
||||
if first_item is None:
|
||||
raise ValueError("Cannot create table from empty iterator")
|
||||
first = to_arrow(first_item)
|
||||
schema = first.schema
|
||||
|
||||
def iter():
|
||||
yield from first.to_batches()
|
||||
for item in data:
|
||||
batch = to_arrow(item)
|
||||
if batch.schema != schema:
|
||||
try:
|
||||
batch = batch.cast(schema)
|
||||
except pa.lib.ArrowInvalid:
|
||||
raise ValueError(
|
||||
f"Input iterator yielded a batch with schema that "
|
||||
f"does not match the schema of other batches.\n"
|
||||
f"Expected:\n{schema}\nGot:\n{batch.schema}"
|
||||
)
|
||||
yield from batch.to_batches()
|
||||
|
||||
reader = pa.RecordBatchReader.from_batches(schema, iter())
|
||||
return to_scannable(reader)
|
||||
|
||||
|
||||
_registered_modules: set[str] = set()
|
||||
|
||||
|
||||
def _register_optional_converters():
|
||||
"""Register converters for optional dependencies that are already imported."""
|
||||
|
||||
if "pandas" in sys.modules and "pandas" not in _registered_modules:
|
||||
_registered_modules.add("pandas")
|
||||
import pandas as pd
|
||||
|
||||
@to_arrow.register(pd.DataFrame)
|
||||
def _arrow_from_pandas(data: pd.DataFrame) -> pa.Table:
|
||||
table = pa.Table.from_pandas(data, preserve_index=False)
|
||||
return table.replace_schema_metadata(None)
|
||||
|
||||
@to_scannable.register(pd.DataFrame)
|
||||
def _from_pandas(data: pd.DataFrame) -> Scannable:
|
||||
return to_scannable(_arrow_from_pandas(data))
|
||||
|
||||
if "polars" in sys.modules and "polars" not in _registered_modules:
|
||||
_registered_modules.add("polars")
|
||||
import polars as pl
|
||||
|
||||
@to_arrow.register(pl.DataFrame)
|
||||
def _arrow_from_polars(data: pl.DataFrame) -> pa.Table:
|
||||
return data.to_arrow()
|
||||
|
||||
@to_scannable.register(pl.DataFrame)
|
||||
def _from_polars(data: pl.DataFrame) -> Scannable:
|
||||
arrow = data.to_arrow()
|
||||
return Scannable(
|
||||
schema=arrow.schema, num_rows=len(data), reader=arrow.to_reader
|
||||
)
|
||||
|
||||
@to_scannable.register(pl.LazyFrame)
|
||||
def _from_polars_lazy(data: pl.LazyFrame) -> Scannable:
|
||||
arrow = data.collect().to_arrow()
|
||||
return Scannable(
|
||||
schema=arrow.schema, num_rows=arrow.num_rows, reader=arrow.to_reader
|
||||
)
|
||||
|
||||
if "datasets" in sys.modules and "datasets" not in _registered_modules:
|
||||
_registered_modules.add("datasets")
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import DatasetDict as HFDatasetDict
|
||||
|
||||
@to_scannable.register(HFDataset)
|
||||
def _from_hf_dataset(data: HFDataset) -> Scannable:
|
||||
table = data.data.table # Access underlying Arrow table
|
||||
return Scannable(
|
||||
schema=table.schema, num_rows=len(data), reader=table.to_reader
|
||||
)
|
||||
|
||||
@to_scannable.register(HFDatasetDict)
|
||||
def _from_hf_dataset_dict(data: HFDatasetDict) -> Scannable:
|
||||
# HuggingFace DatasetDict: combine all splits with a 'split' column
|
||||
schema = data[list(data.keys())[0]].features.arrow_schema
|
||||
if "split" not in schema.names:
|
||||
schema = schema.append(pa.field("split", pa.string()))
|
||||
|
||||
def gen():
|
||||
for split_name, dataset in data.items():
|
||||
for batch in dataset.data.to_batches():
|
||||
split_arr = pa.array(
|
||||
[split_name] * len(batch), type=pa.string()
|
||||
)
|
||||
yield pa.RecordBatch.from_arrays(
|
||||
list(batch.columns) + [split_arr], schema=schema
|
||||
)
|
||||
|
||||
total_rows = sum(len(dataset) for dataset in data.values())
|
||||
return Scannable(
|
||||
schema=schema,
|
||||
num_rows=total_rows,
|
||||
reader=lambda: pa.RecordBatchReader.from_batches(schema, gen()),
|
||||
)
|
||||
|
||||
if "lance" in sys.modules and "lance" not in _registered_modules:
|
||||
_registered_modules.add("lance")
|
||||
import lance
|
||||
|
||||
@to_scannable.register(lance.LanceDataset)
|
||||
def _from_lance(data: lance.LanceDataset) -> Scannable:
|
||||
return Scannable(
|
||||
schema=data.schema,
|
||||
num_rows=data.count_rows(),
|
||||
reader=lambda: data.scanner().to_reader(),
|
||||
)
|
||||
|
||||
|
||||
# Register on module load
|
||||
_register_optional_converters()
|
||||
@@ -25,6 +25,8 @@ from typing import (
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from lancedb.scannable import _register_optional_converters, to_scannable
|
||||
|
||||
from . import __version__
|
||||
from lancedb.arrow import peek_reader
|
||||
from lancedb.background_loop import LOOP
|
||||
@@ -1329,7 +1331,7 @@ class Table(ABC):
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.delete("x = 2")
|
||||
DeleteResult(version=2)
|
||||
DeleteResult(num_deleted_rows=1, version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -1343,7 +1345,7 @@ class Table(ABC):
|
||||
>>> to_remove
|
||||
'1, 5'
|
||||
>>> table.delete(f"x IN ({to_remove})")
|
||||
DeleteResult(version=3)
|
||||
DeleteResult(num_deleted_rows=1, version=3)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 3 [5.0, 6.0]
|
||||
@@ -1744,6 +1746,8 @@ class LanceTable(Table):
|
||||
storage_options_provider: Optional["StorageOptionsProvider"] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
_async: AsyncTable = None,
|
||||
):
|
||||
if namespace is None:
|
||||
@@ -1751,6 +1755,7 @@ class LanceTable(Table):
|
||||
self._conn = connection
|
||||
self._namespace = namespace
|
||||
self._location = location # Store location for use in _dataset_path
|
||||
self._namespace_client = namespace_client
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
@@ -1762,6 +1767,8 @@ class LanceTable(Table):
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
location=location,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1804,6 +1811,8 @@ class LanceTable(Table):
|
||||
storage_options_provider: Optional["StorageOptionsProvider"] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
):
|
||||
if namespace is None:
|
||||
namespace = []
|
||||
@@ -1815,6 +1824,8 @@ class LanceTable(Table):
|
||||
storage_options_provider=storage_options_provider,
|
||||
index_cache_size=index_cache_size,
|
||||
location=location,
|
||||
namespace_client=namespace_client,
|
||||
managed_versioning=managed_versioning,
|
||||
)
|
||||
|
||||
# check the dataset exists
|
||||
@@ -1846,6 +1857,16 @@ class LanceTable(Table):
|
||||
"Please install with `pip install pylance`."
|
||||
)
|
||||
|
||||
if self._namespace_client is not None:
|
||||
table_id = self._namespace + [self.name]
|
||||
return lance.dataset(
|
||||
version=self.version,
|
||||
storage_options=self._conn.storage_options,
|
||||
namespace=self._namespace_client,
|
||||
table_id=table_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lance.dataset(
|
||||
self._dataset_path,
|
||||
version=self.version,
|
||||
@@ -2711,6 +2732,7 @@ class LanceTable(Table):
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -2771,6 +2793,7 @@ class LanceTable(Table):
|
||||
self._conn = db
|
||||
self._namespace = namespace
|
||||
self._location = location
|
||||
self._namespace_client = namespace_client
|
||||
|
||||
if data_storage_version is not None:
|
||||
warnings.warn(
|
||||
@@ -3727,18 +3750,31 @@ class AsyncTable:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
allow_subschema=True,
|
||||
)
|
||||
if isinstance(data, pa.Table):
|
||||
data = data.to_reader()
|
||||
|
||||
return await self._inner.add(data, mode or "append")
|
||||
# _santitize_data is an old code path, but we will use it until the
|
||||
# new code path is ready.
|
||||
if on_bad_vectors != "error" or (
|
||||
schema.metadata is not None and b"embedding_functions" in schema.metadata
|
||||
):
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
allow_subschema=True,
|
||||
)
|
||||
_register_optional_converters()
|
||||
data = to_scannable(data)
|
||||
try:
|
||||
return await self._inner.add(data, mode or "append")
|
||||
except RuntimeError as e:
|
||||
if "Cast error" in str(e):
|
||||
raise ValueError(e)
|
||||
elif "Vector column contains NaN" in str(e):
|
||||
raise ValueError(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
@@ -4200,7 +4236,7 @@ class AsyncTable:
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.delete("x = 2")
|
||||
DeleteResult(version=2)
|
||||
DeleteResult(num_deleted_rows=1, version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -4214,7 +4250,7 @@ class AsyncTable:
|
||||
>>> to_remove
|
||||
'1, 5'
|
||||
>>> table.delete(f"x IN ({to_remove})")
|
||||
DeleteResult(version=3)
|
||||
DeleteResult(num_deleted_rows=1, version=3)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 3 [5.0, 6.0]
|
||||
|
||||
@@ -324,6 +324,16 @@ def _(value: list):
|
||||
return "[" + ", ".join(map(value_to_sql, value)) + "]"
|
||||
|
||||
|
||||
@value_to_sql.register(dict)
|
||||
def _(value: dict):
|
||||
# https://datafusion.apache.org/user-guide/sql/scalar_functions.html#named-struct
|
||||
return (
|
||||
"named_struct("
|
||||
+ ", ".join(f"'{k}', {value_to_sql(v)}" for k, v in value.items())
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
@value_to_sql.register(np.ndarray)
|
||||
def _(value: np.ndarray):
|
||||
return value_to_sql(value.tolist())
|
||||
|
||||
@@ -515,3 +515,34 @@ def test_openai_propagates_api_key(monkeypatch):
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
assert len(actual.text) > 0
|
||||
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_openai_no_retry_on_401(mock_sleep):
|
||||
"""
|
||||
Test that OpenAI embedding function does not retry on 401 authentication
|
||||
errors.
|
||||
"""
|
||||
from lancedb.embeddings.utils import retry_with_exponential_backoff
|
||||
|
||||
# Create a mock that raises an AuthenticationError
|
||||
class MockAuthenticationError(Exception):
|
||||
"""Mock OpenAI AuthenticationError"""
|
||||
|
||||
pass
|
||||
|
||||
MockAuthenticationError.__name__ = "AuthenticationError"
|
||||
|
||||
mock_func = MagicMock(side_effect=MockAuthenticationError("Invalid API key"))
|
||||
|
||||
# Wrap the function with retry logic
|
||||
wrapped_func = retry_with_exponential_backoff(mock_func, max_retries=3)
|
||||
|
||||
# Should raise without retrying
|
||||
with pytest.raises(MockAuthenticationError):
|
||||
wrapped_func()
|
||||
|
||||
# Verify that the function was only called once (no retries)
|
||||
assert mock_func.call_count == 1
|
||||
# Verify that sleep was never called (no retries)
|
||||
assert mock_sleep.call_count == 0
|
||||
|
||||
@@ -27,6 +27,7 @@ from lancedb.query import (
|
||||
PhraseQuery,
|
||||
BooleanQuery,
|
||||
Occur,
|
||||
LanceFtsQueryBuilder,
|
||||
)
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
@@ -882,3 +883,109 @@ def test_fts_query_to_json():
|
||||
'"must_not":[]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
|
||||
def test_fts_fast_search(table):
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
|
||||
# Insert some unindexed data
|
||||
table.add(
|
||||
[
|
||||
{
|
||||
"text": "xyz",
|
||||
"vector": [0 for _ in range(128)],
|
||||
"id": 101,
|
||||
"text2": "xyz",
|
||||
"nested": {"text": "xyz"},
|
||||
"count": 10,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Without fast_search, the query object should not have fast_search set
|
||||
builder = table.search("xyz", query_type="fts").limit(10)
|
||||
query = builder.to_query_object()
|
||||
assert query.fast_search is None
|
||||
|
||||
# With fast_search, the query object should have fast_search=True
|
||||
builder = table.search("xyz", query_type="fts").fast_search().limit(10)
|
||||
query = builder.to_query_object()
|
||||
assert query.fast_search is True
|
||||
|
||||
# fast_search should be chainable with other methods
|
||||
builder = (
|
||||
table.search("xyz", query_type="fts").fast_search().select(["text"]).limit(5)
|
||||
)
|
||||
query = builder.to_query_object()
|
||||
assert query.fast_search is True
|
||||
assert query.limit == 5
|
||||
assert query.columns == ["text"]
|
||||
|
||||
# fast_search should be enabled by keyword argument too
|
||||
query = LanceFtsQueryBuilder(table, "xyz", fast_search=True).to_query_object()
|
||||
assert query.fast_search is True
|
||||
|
||||
# Verify it executes without error and skips unindexed data
|
||||
results = table.search("xyz", query_type="fts").fast_search().limit(5).to_list()
|
||||
assert len(results) == 0
|
||||
|
||||
# Update index and verify it returns results
|
||||
table.optimize()
|
||||
results = table.search("xyz", query_type="fts").fast_search().limit(5).to_list()
|
||||
assert len(results) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fts_fast_search_async(async_table):
|
||||
await async_table.create_index("text", config=FTS())
|
||||
|
||||
# Insert some unindexed data
|
||||
await async_table.add(
|
||||
[
|
||||
{
|
||||
"text": "xyz",
|
||||
"vector": [0 for _ in range(128)],
|
||||
"id": 101,
|
||||
"text2": "xyz",
|
||||
"nested": {"text": "xyz"},
|
||||
"count": 10,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Without fast_search, should return results
|
||||
results = await async_table.query().nearest_to_text("xyz").limit(5).to_list()
|
||||
assert len(results) > 0
|
||||
|
||||
# With fast_search, should return no results data unindexed
|
||||
fast_results = (
|
||||
await async_table.query()
|
||||
.nearest_to_text("xyz")
|
||||
.fast_search()
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(fast_results) == 0
|
||||
|
||||
# Update index and verify it returns results
|
||||
await async_table.optimize()
|
||||
|
||||
fast_results = (
|
||||
await async_table.query()
|
||||
.nearest_to_text("xyz")
|
||||
.fast_search()
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(fast_results) > 0
|
||||
|
||||
# fast_search should be chainable with other methods
|
||||
results = (
|
||||
await async_table.query()
|
||||
.nearest_to_text("xyz")
|
||||
.fast_search()
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) > 0
|
||||
|
||||
@@ -531,6 +531,78 @@ def test_empty_result_reranker():
|
||||
)
|
||||
|
||||
|
||||
def test_empty_hybrid_result_reranker():
|
||||
"""Test that hybrid search with empty results after filtering doesn't crash.
|
||||
|
||||
Regression test for https://github.com/lancedb/lancedb/issues/2425
|
||||
"""
|
||||
from lancedb.query import LanceHybridQueryBuilder
|
||||
|
||||
# Simulate empty vector and FTS results with the expected schema
|
||||
vector_schema = pa.schema(
|
||||
[
|
||||
("text", pa.string()),
|
||||
("vector", pa.list_(pa.float32(), 4)),
|
||||
("_rowid", pa.uint64()),
|
||||
("_distance", pa.float32()),
|
||||
]
|
||||
)
|
||||
fts_schema = pa.schema(
|
||||
[
|
||||
("text", pa.string()),
|
||||
("vector", pa.list_(pa.float32(), 4)),
|
||||
("_rowid", pa.uint64()),
|
||||
("_score", pa.float32()),
|
||||
]
|
||||
)
|
||||
empty_vector = pa.table(
|
||||
{
|
||||
"text": pa.array([], type=pa.string()),
|
||||
"vector": pa.array([], type=pa.list_(pa.float32(), 4)),
|
||||
"_rowid": pa.array([], type=pa.uint64()),
|
||||
"_distance": pa.array([], type=pa.float32()),
|
||||
},
|
||||
schema=vector_schema,
|
||||
)
|
||||
empty_fts = pa.table(
|
||||
{
|
||||
"text": pa.array([], type=pa.string()),
|
||||
"vector": pa.array([], type=pa.list_(pa.float32(), 4)),
|
||||
"_rowid": pa.array([], type=pa.uint64()),
|
||||
"_score": pa.array([], type=pa.float32()),
|
||||
},
|
||||
schema=fts_schema,
|
||||
)
|
||||
|
||||
for reranker in [LinearCombinationReranker(), RRFReranker()]:
|
||||
result = LanceHybridQueryBuilder._combine_hybrid_results(
|
||||
fts_results=empty_fts,
|
||||
vector_results=empty_vector,
|
||||
norm="score",
|
||||
fts_query="nonexistent query",
|
||||
reranker=reranker,
|
||||
limit=10,
|
||||
with_row_ids=False,
|
||||
)
|
||||
assert len(result) == 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert "_rowid" not in result.column_names
|
||||
|
||||
# Also test with with_row_ids=True
|
||||
result = LanceHybridQueryBuilder._combine_hybrid_results(
|
||||
fts_results=empty_fts,
|
||||
vector_results=empty_vector,
|
||||
norm="score",
|
||||
fts_query="nonexistent query",
|
||||
reranker=LinearCombinationReranker(),
|
||||
limit=10,
|
||||
with_row_ids=True,
|
||||
)
|
||||
assert len(result) == 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert "_rowid" in result.column_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
|
||||
@@ -326,6 +326,24 @@ def test_add_struct(mem_db: DBConnection):
|
||||
table = mem_db.create_table("test2", schema=schema)
|
||||
table.add(data)
|
||||
|
||||
struct_type = pa.struct(
|
||||
[
|
||||
("b", pa.int64()),
|
||||
("a", pa.int64()),
|
||||
]
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"s_list": [
|
||||
[
|
||||
pa.scalar({"b": 1, "a": 2}, type=struct_type),
|
||||
pa.scalar({"b": 4, "a": None}, type=struct_type),
|
||||
]
|
||||
],
|
||||
}
|
||||
)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
|
||||
def test_add_subschema(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
@@ -810,7 +828,7 @@ def test_create_index_name_and_train_parameters(
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(mem_db: DBConnection):
|
||||
def test_create_with_nans(mem_db: DBConnection):
|
||||
# by default we raise an error on bad input vectors
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
@@ -854,6 +872,57 @@ def test_add_with_nans(mem_db: DBConnection):
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_nans(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
|
||||
pa.field("item", pa.string(), nullable=True),
|
||||
pa.field("price", pa.float64(), nullable=False),
|
||||
],
|
||||
)
|
||||
table = mem_db.create_table("test", schema=schema)
|
||||
# by default we raise an error on bad input vectors
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
table.add(
|
||||
data=[row],
|
||||
)
|
||||
|
||||
table.add(
|
||||
[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
assert len(table) == 2
|
||||
table.delete("true")
|
||||
|
||||
# We can fill bad input with some value
|
||||
table.add(
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
assert len(table) == 3
|
||||
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_restore(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
|
||||
@@ -121,6 +121,32 @@ def test_value_to_sql_string(tmp_path):
|
||||
assert table.to_pandas().query("search == @value")["replace"].item() == value
|
||||
|
||||
|
||||
def test_value_to_sql_dict():
|
||||
# Simple flat struct
|
||||
assert value_to_sql({"a": 1, "b": "hello"}) == "named_struct('a', 1, 'b', 'hello')"
|
||||
|
||||
# Nested struct
|
||||
assert (
|
||||
value_to_sql({"outer": {"inner": 1}})
|
||||
== "named_struct('outer', named_struct('inner', 1))"
|
||||
)
|
||||
|
||||
# List inside struct
|
||||
assert value_to_sql({"a": [1, 2]}) == "named_struct('a', [1, 2])"
|
||||
|
||||
# Mixed types
|
||||
assert (
|
||||
value_to_sql({"name": "test", "count": 42, "rate": 3.14, "active": True})
|
||||
== "named_struct('name', 'test', 'count', 42, 'rate', 3.14, 'active', TRUE)"
|
||||
)
|
||||
|
||||
# Null value inside struct
|
||||
assert value_to_sql({"a": None}) == "named_struct('a', NULL)"
|
||||
|
||||
# Empty dict
|
||||
assert value_to_sql({}) == "named_struct()"
|
||||
|
||||
|
||||
def test_append_vector_columns():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry.register("test")(MockTextEmbeddingFunction)
|
||||
@@ -292,18 +318,14 @@ class TestModel(lancedb.pydantic.LanceModel):
|
||||
lambda: pa.table({"a": [1], "b": [2]}),
|
||||
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
|
||||
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
)
|
||||
),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner()
|
||||
lambda: lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
),
|
||||
lambda: lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner(),
|
||||
lambda: pd.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
|
||||
|
||||
@@ -10,7 +10,7 @@ use arrow::{
|
||||
use futures::stream::StreamExt;
|
||||
use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{
|
||||
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
|
||||
Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
|
||||
@@ -9,15 +9,16 @@ use lancedb::{
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods,
|
||||
types::{PyDict, PyDictMethods},
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt, storage_options::py_object_to_storage_options_provider, table::Table,
|
||||
error::PythonErrorExt, namespace::extract_namespace_arc,
|
||||
storage_options::py_object_to_storage_options_provider, table::Table,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
@@ -182,7 +183,8 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, namespace=vec![], storage_options = None, storage_options_provider=None, index_cache_size = None, location=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[pyo3(signature = (name, namespace=vec![], storage_options = None, storage_options_provider=None, index_cache_size = None, location=None, namespace_client=None, managed_versioning=None))]
|
||||
pub fn open_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
@@ -191,11 +193,13 @@ impl Connection {
|
||||
storage_options_provider: Option<Py<PyAny>>,
|
||||
index_cache_size: Option<u32>,
|
||||
location: Option<String>,
|
||||
namespace_client: Option<Py<PyAny>>,
|
||||
managed_versioning: Option<bool>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
|
||||
let mut builder = inner.open_table(name);
|
||||
builder = builder.namespace(namespace);
|
||||
builder = builder.namespace(namespace.clone());
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
@@ -209,6 +213,20 @@ impl Connection {
|
||||
if let Some(location) = location {
|
||||
builder = builder.location(location);
|
||||
}
|
||||
// Extract namespace client from Python object if provided
|
||||
let ns_client = if let Some(ns_obj) = namespace_client {
|
||||
let py = self_.py();
|
||||
Some(extract_namespace_arc(py, ns_obj)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(ns_client) = ns_client {
|
||||
builder = builder.namespace_client(ns_client);
|
||||
}
|
||||
// Pass managed_versioning if provided to avoid redundant describe_table call
|
||||
if let Some(enabled) = managed_versioning {
|
||||
builder = builder.managed_versioning(enabled);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = builder.execute().await.infer_error()?;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::{
|
||||
PyErr, PyResult, Python,
|
||||
exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError},
|
||||
intern,
|
||||
types::{PyAnyMethods, PyNone},
|
||||
PyErr, PyResult, Python,
|
||||
};
|
||||
|
||||
use lancedb::error::Error as LanceError;
|
||||
|
||||
@@ -3,17 +3,17 @@
|
||||
|
||||
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder};
|
||||
use lancedb::index::{
|
||||
Index as LanceDbIndex,
|
||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
Index as LanceDbIndex,
|
||||
};
|
||||
use pyo3::types::PyStringMethods;
|
||||
use pyo3::IntoPyObject;
|
||||
use pyo3::types::PyStringMethods;
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, PyAny, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyValueError},
|
||||
intern, pyclass, pymethods,
|
||||
types::PyAnyMethods,
|
||||
Bound, FromPyObject, PyAny, PyResult, Python,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
@@ -41,7 +41,12 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
let inner_opts = FtsIndexBuilder::default()
|
||||
.base_tokenizer(params.base_tokenizer)
|
||||
.language(¶ms.language)
|
||||
.map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
|
||||
.map_err(|_| {
|
||||
PyValueError::new_err(format!(
|
||||
"LanceDB does not support the requested language: '{}'",
|
||||
params.language
|
||||
))
|
||||
})?
|
||||
.with_position(params.with_position)
|
||||
.lower_case(params.lower_case)
|
||||
.max_token_length(params.max_token_length)
|
||||
@@ -52,7 +57,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
.ngram_max_length(params.ngram_max_length)
|
||||
.ngram_prefix_only(params.prefix_only);
|
||||
Ok(LanceDbIndex::FTS(inner_opts))
|
||||
},
|
||||
}
|
||||
"IvfFlat" => {
|
||||
let params = source.extract::<IvfFlatParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -64,10 +69,11 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(target_partition_size) = params.target_partition_size {
|
||||
ivf_flat_builder = ivf_flat_builder.target_partition_size(target_partition_size);
|
||||
ivf_flat_builder =
|
||||
ivf_flat_builder.target_partition_size(target_partition_size);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
|
||||
},
|
||||
}
|
||||
"IvfPq" => {
|
||||
let params = source.extract::<IvfPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -86,7 +92,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
|
||||
},
|
||||
}
|
||||
"IvfSq" => {
|
||||
let params = source.extract::<IvfSqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -101,7 +107,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
ivf_sq_builder = ivf_sq_builder.target_partition_size(target_partition_size);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfSq(ivf_sq_builder))
|
||||
},
|
||||
}
|
||||
"IvfRq" => {
|
||||
let params = source.extract::<IvfRqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -117,7 +123,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
|
||||
},
|
||||
}
|
||||
"HnswPq" => {
|
||||
let params = source.extract::<IvfHnswPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -138,7 +144,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
|
||||
},
|
||||
}
|
||||
"HnswSq" => {
|
||||
let params = source.extract::<IvfHnswSqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -155,7 +161,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
hnsw_sq_builder = hnsw_sq_builder.target_partition_size(target_partition_size);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
|
||||
},
|
||||
}
|
||||
not_supported => Err(PyValueError::new_err(format!(
|
||||
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, or IvfHnswSq",
|
||||
not_supported
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use connection::{Connection, connect};
|
||||
use env_logger::Env;
|
||||
use index::IndexConfig;
|
||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
Bound, PyResult, Python, pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
wrap_pyfunction, Bound, PyResult, Python,
|
||||
wrap_pyfunction,
|
||||
};
|
||||
use query::{FTSQuery, HybridQuery, Query, VectorQuery};
|
||||
use session::Session;
|
||||
@@ -23,6 +23,7 @@ pub mod connection;
|
||||
pub mod error;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod session;
|
||||
|
||||
696
python/src/namespace.rs
Normal file
696
python/src/namespace.rs
Normal file
@@ -0,0 +1,696 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Namespace utilities for Python bindings
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use lance_namespace::LanceNamespace as LanceNamespaceTrait;
|
||||
use lance_namespace::models::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
/// Wrapper that allows any Python object implementing LanceNamespace protocol
|
||||
/// to be used as a Rust LanceNamespace.
|
||||
///
|
||||
/// This is similar to PyLanceNamespace in lance's Python bindings - it wraps a Python
|
||||
/// object and calls back into Python when namespace methods are invoked.
|
||||
pub struct PyLanceNamespace {
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
namespace_id: String,
|
||||
}
|
||||
|
||||
impl PyLanceNamespace {
|
||||
/// Create a new PyLanceNamespace wrapper around a Python namespace object.
|
||||
pub fn new(_py: Python<'_>, py_namespace: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
let namespace_id = py_namespace
|
||||
.call_method0("namespace_id")?
|
||||
.extract::<String>()?;
|
||||
|
||||
Ok(Self {
|
||||
py_namespace: Arc::new(py_namespace.clone().unbind()),
|
||||
namespace_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an Arc<dyn LanceNamespace> from a Python namespace object.
|
||||
pub fn create_arc(
|
||||
py: Python<'_>,
|
||||
py_namespace: &Bound<'_, PyAny>,
|
||||
) -> PyResult<Arc<dyn LanceNamespaceTrait>> {
|
||||
let wrapper = Self::new(py, py_namespace)?;
|
||||
Ok(Arc::new(wrapper))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PyLanceNamespace {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PyLanceNamespace {{ id: {} }}", self.namespace_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create the DictWithModelDump class in Python.
|
||||
/// This class acts like a dict but also has model_dump() method.
|
||||
/// This allows it to work with both:
|
||||
/// - depythonize (which expects a dict/Mapping)
|
||||
/// - Python code that calls .model_dump() (like DirectoryNamespace wrapper)
|
||||
fn get_dict_with_model_dump_class(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
|
||||
// Use a module-level cache via __builtins__
|
||||
let builtins = py.import("builtins")?;
|
||||
if builtins.hasattr("_DictWithModelDump")? {
|
||||
return builtins.getattr("_DictWithModelDump");
|
||||
}
|
||||
|
||||
// Create the class using exec
|
||||
let locals = PyDict::new(py);
|
||||
py.run(
|
||||
c"class DictWithModelDump(dict):
|
||||
def model_dump(self):
|
||||
return dict(self)",
|
||||
None,
|
||||
Some(&locals),
|
||||
)?;
|
||||
let class = locals.get_item("DictWithModelDump")?.ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err("Failed to create DictWithModelDump class")
|
||||
})?;
|
||||
|
||||
// Cache it
|
||||
builtins.setattr("_DictWithModelDump", &class)?;
|
||||
Ok(class)
|
||||
}
|
||||
|
||||
/// Helper to call a Python namespace method with JSON serialization.
|
||||
/// For methods that take a request and return a response.
|
||||
/// Uses DictWithModelDump to pass a dict that also has model_dump() method,
|
||||
/// making it compatible with both depythonize and Python wrappers.
|
||||
async fn call_py_method<Req, Resp>(
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
method_name: &'static str,
|
||||
request: Req,
|
||||
) -> lance_core::Result<Resp>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
Resp: serde::de::DeserializeOwned + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to serialize request for {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response_json = tokio::task::spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let json_module = py.import("json")?;
|
||||
let request_dict = json_module.call_method1("loads", (&request_json,))?;
|
||||
|
||||
// Wrap dict in DictWithModelDump so it works with both depythonize and .model_dump()
|
||||
let dict_class = get_dict_with_model_dump_class(py)?;
|
||||
let request_arg = dict_class.call1((request_dict,))?;
|
||||
|
||||
// Call the Python method
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
|
||||
// Convert response to dict, then to JSON
|
||||
// Pydantic models have model_dump() method
|
||||
let result_dict = if result.bind(py).hasattr("model_dump")? {
|
||||
result.call_method0(py, "model_dump")?
|
||||
} else {
|
||||
result
|
||||
};
|
||||
let response_json: String = json_module
|
||||
.call_method1("dumps", (result_dict,))?
|
||||
.extract()?;
|
||||
Ok::<_, PyErr>(response_json)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))?
|
||||
.map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))?;
|
||||
|
||||
serde_json::from_str(&response_json).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to deserialize response from {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper for methods that return () on success
|
||||
async fn call_py_method_unit<Req>(
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
method_name: &'static str,
|
||||
request: Req,
|
||||
) -> lance_core::Result<()>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to serialize request for {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let json_module = py.import("json")?;
|
||||
let request_dict = json_module.call_method1("loads", (&request_json,))?;
|
||||
|
||||
// Wrap dict in DictWithModelDump
|
||||
let dict_class = get_dict_with_model_dump_class(py)?;
|
||||
let request_arg = dict_class.call1((request_dict,))?;
|
||||
|
||||
// Call the Python method
|
||||
py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
Ok::<_, PyErr>(())
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))?
|
||||
.map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))
|
||||
}
|
||||
|
||||
/// Helper for methods that return a primitive type
|
||||
async fn call_py_method_primitive<Req, Resp>(
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
method_name: &'static str,
|
||||
request: Req,
|
||||
) -> lance_core::Result<Resp>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
Resp: for<'py> pyo3::FromPyObject<'py> + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to serialize request for {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let json_module = py.import("json")?;
|
||||
let request_dict = json_module.call_method1("loads", (&request_json,))?;
|
||||
|
||||
// Wrap dict in DictWithModelDump
|
||||
let dict_class = get_dict_with_model_dump_class(py)?;
|
||||
let request_arg = dict_class.call1((request_dict,))?;
|
||||
|
||||
// Call the Python method
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
let value: Resp = result.extract(py)?;
|
||||
Ok::<_, PyErr>(value)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))?
|
||||
.map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))
|
||||
}
|
||||
|
||||
/// Helper for methods that return Bytes
|
||||
async fn call_py_method_bytes<Req>(
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
method_name: &'static str,
|
||||
request: Req,
|
||||
) -> lance_core::Result<Bytes>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to serialize request for {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let json_module = py.import("json")?;
|
||||
let request_dict = json_module.call_method1("loads", (&request_json,))?;
|
||||
|
||||
// Wrap dict in DictWithModelDump
|
||||
let dict_class = get_dict_with_model_dump_class(py)?;
|
||||
let request_arg = dict_class.call1((request_dict,))?;
|
||||
|
||||
// Call the Python method
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
let bytes_data: Vec<u8> = result.extract(py)?;
|
||||
Ok::<_, PyErr>(Bytes::from(bytes_data))
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))?
|
||||
.map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))
|
||||
}
|
||||
|
||||
/// Helper for methods that take request + data and return a response
|
||||
async fn call_py_method_with_data<Req, Resp>(
|
||||
py_namespace: Arc<Py<PyAny>>,
|
||||
method_name: &'static str,
|
||||
request: Req,
|
||||
data: Bytes,
|
||||
) -> lance_core::Result<Resp>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
Resp: serde::de::DeserializeOwned + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to serialize request for {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response_json = tokio::task::spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let json_module = py.import("json")?;
|
||||
let request_dict = json_module.call_method1("loads", (&request_json,))?;
|
||||
|
||||
// Wrap dict in DictWithModelDump
|
||||
let dict_class = get_dict_with_model_dump_class(py)?;
|
||||
let request_arg = dict_class.call1((request_dict,))?;
|
||||
|
||||
// Pass request and bytes to Python method
|
||||
let py_bytes = pyo3::types::PyBytes::new(py, &data);
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg, py_bytes))?;
|
||||
|
||||
// Convert response dict to JSON
|
||||
let response_json: String = json_module.call_method1("dumps", (result,))?.extract()?;
|
||||
Ok::<_, PyErr>(response_json)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))?
|
||||
.map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))?;
|
||||
|
||||
serde_json::from_str(&response_json).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
"Failed to deserialize response from {}: {}",
|
||||
method_name, e
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LanceNamespaceTrait for PyLanceNamespace {
|
||||
fn namespace_id(&self) -> String {
|
||||
self.namespace_id.clone()
|
||||
}
|
||||
|
||||
async fn list_namespaces(
|
||||
&self,
|
||||
request: ListNamespacesRequest,
|
||||
) -> lance_core::Result<ListNamespacesResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_namespaces", request).await
|
||||
}
|
||||
|
||||
async fn describe_namespace(
|
||||
&self,
|
||||
request: DescribeNamespaceRequest,
|
||||
) -> lance_core::Result<DescribeNamespaceResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "describe_namespace", request).await
|
||||
}
|
||||
|
||||
async fn create_namespace(
|
||||
&self,
|
||||
request: CreateNamespaceRequest,
|
||||
) -> lance_core::Result<CreateNamespaceResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "create_namespace", request).await
|
||||
}
|
||||
|
||||
async fn drop_namespace(
|
||||
&self,
|
||||
request: DropNamespaceRequest,
|
||||
) -> lance_core::Result<DropNamespaceResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "drop_namespace", request).await
|
||||
}
|
||||
|
||||
async fn namespace_exists(&self, request: NamespaceExistsRequest) -> lance_core::Result<()> {
|
||||
call_py_method_unit(self.py_namespace.clone(), "namespace_exists", request).await
|
||||
}
|
||||
|
||||
async fn list_tables(
|
||||
&self,
|
||||
request: ListTablesRequest,
|
||||
) -> lance_core::Result<ListTablesResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_tables", request).await
|
||||
}
|
||||
|
||||
async fn describe_table(
|
||||
&self,
|
||||
request: DescribeTableRequest,
|
||||
) -> lance_core::Result<DescribeTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "describe_table", request).await
|
||||
}
|
||||
|
||||
async fn register_table(
|
||||
&self,
|
||||
request: RegisterTableRequest,
|
||||
) -> lance_core::Result<RegisterTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "register_table", request).await
|
||||
}
|
||||
|
||||
async fn table_exists(&self, request: TableExistsRequest) -> lance_core::Result<()> {
|
||||
call_py_method_unit(self.py_namespace.clone(), "table_exists", request).await
|
||||
}
|
||||
|
||||
async fn drop_table(&self, request: DropTableRequest) -> lance_core::Result<DropTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "drop_table", request).await
|
||||
}
|
||||
|
||||
async fn deregister_table(
|
||||
&self,
|
||||
request: DeregisterTableRequest,
|
||||
) -> lance_core::Result<DeregisterTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "deregister_table", request).await
|
||||
}
|
||||
|
||||
async fn count_table_rows(&self, request: CountTableRowsRequest) -> lance_core::Result<i64> {
|
||||
call_py_method_primitive(self.py_namespace.clone(), "count_table_rows", request).await
|
||||
}
|
||||
|
||||
async fn create_table(
|
||||
&self,
|
||||
request: CreateTableRequest,
|
||||
request_data: Bytes,
|
||||
) -> lance_core::Result<CreateTableResponse> {
|
||||
call_py_method_with_data(
|
||||
self.py_namespace.clone(),
|
||||
"create_table",
|
||||
request,
|
||||
request_data,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn declare_table(
|
||||
&self,
|
||||
request: DeclareTableRequest,
|
||||
) -> lance_core::Result<DeclareTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "declare_table", request).await
|
||||
}
|
||||
|
||||
async fn insert_into_table(
|
||||
&self,
|
||||
request: InsertIntoTableRequest,
|
||||
request_data: Bytes,
|
||||
) -> lance_core::Result<InsertIntoTableResponse> {
|
||||
call_py_method_with_data(
|
||||
self.py_namespace.clone(),
|
||||
"insert_into_table",
|
||||
request,
|
||||
request_data,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn merge_insert_into_table(
|
||||
&self,
|
||||
request: MergeInsertIntoTableRequest,
|
||||
request_data: Bytes,
|
||||
) -> lance_core::Result<MergeInsertIntoTableResponse> {
|
||||
call_py_method_with_data(
|
||||
self.py_namespace.clone(),
|
||||
"merge_insert_into_table",
|
||||
request,
|
||||
request_data,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_table(
|
||||
&self,
|
||||
request: UpdateTableRequest,
|
||||
) -> lance_core::Result<UpdateTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "update_table", request).await
|
||||
}
|
||||
|
||||
async fn delete_from_table(
|
||||
&self,
|
||||
request: DeleteFromTableRequest,
|
||||
) -> lance_core::Result<DeleteFromTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "delete_from_table", request).await
|
||||
}
|
||||
|
||||
async fn query_table(&self, request: QueryTableRequest) -> lance_core::Result<Bytes> {
|
||||
call_py_method_bytes(self.py_namespace.clone(), "query_table", request).await
|
||||
}
|
||||
|
||||
async fn create_table_index(
|
||||
&self,
|
||||
request: CreateTableIndexRequest,
|
||||
) -> lance_core::Result<CreateTableIndexResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "create_table_index", request).await
|
||||
}
|
||||
|
||||
async fn list_table_indices(
|
||||
&self,
|
||||
request: ListTableIndicesRequest,
|
||||
) -> lance_core::Result<ListTableIndicesResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_table_indices", request).await
|
||||
}
|
||||
|
||||
async fn describe_table_index_stats(
|
||||
&self,
|
||||
request: DescribeTableIndexStatsRequest,
|
||||
) -> lance_core::Result<DescribeTableIndexStatsResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"describe_table_index_stats",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn describe_transaction(
|
||||
&self,
|
||||
request: DescribeTransactionRequest,
|
||||
) -> lance_core::Result<DescribeTransactionResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "describe_transaction", request).await
|
||||
}
|
||||
|
||||
async fn alter_transaction(
|
||||
&self,
|
||||
request: AlterTransactionRequest,
|
||||
) -> lance_core::Result<AlterTransactionResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "alter_transaction", request).await
|
||||
}
|
||||
|
||||
async fn create_table_scalar_index(
|
||||
&self,
|
||||
request: CreateTableIndexRequest,
|
||||
) -> lance_core::Result<CreateTableScalarIndexResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"create_table_scalar_index",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn drop_table_index(
|
||||
&self,
|
||||
request: DropTableIndexRequest,
|
||||
) -> lance_core::Result<DropTableIndexResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "drop_table_index", request).await
|
||||
}
|
||||
|
||||
async fn list_all_tables(
|
||||
&self,
|
||||
request: ListTablesRequest,
|
||||
) -> lance_core::Result<ListTablesResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_all_tables", request).await
|
||||
}
|
||||
|
||||
async fn restore_table(
|
||||
&self,
|
||||
request: RestoreTableRequest,
|
||||
) -> lance_core::Result<RestoreTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "restore_table", request).await
|
||||
}
|
||||
|
||||
async fn rename_table(
|
||||
&self,
|
||||
request: RenameTableRequest,
|
||||
) -> lance_core::Result<RenameTableResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "rename_table", request).await
|
||||
}
|
||||
|
||||
async fn list_table_versions(
|
||||
&self,
|
||||
request: ListTableVersionsRequest,
|
||||
) -> lance_core::Result<ListTableVersionsResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_table_versions", request).await
|
||||
}
|
||||
|
||||
async fn create_table_version(
|
||||
&self,
|
||||
request: CreateTableVersionRequest,
|
||||
) -> lance_core::Result<CreateTableVersionResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "create_table_version", request).await
|
||||
}
|
||||
|
||||
async fn describe_table_version(
|
||||
&self,
|
||||
request: DescribeTableVersionRequest,
|
||||
) -> lance_core::Result<DescribeTableVersionResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "describe_table_version", request).await
|
||||
}
|
||||
|
||||
async fn batch_delete_table_versions(
|
||||
&self,
|
||||
request: BatchDeleteTableVersionsRequest,
|
||||
) -> lance_core::Result<BatchDeleteTableVersionsResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"batch_delete_table_versions",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_table_schema_metadata(
|
||||
&self,
|
||||
request: UpdateTableSchemaMetadataRequest,
|
||||
) -> lance_core::Result<UpdateTableSchemaMetadataResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"update_table_schema_metadata",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_table_stats(
|
||||
&self,
|
||||
request: GetTableStatsRequest,
|
||||
) -> lance_core::Result<GetTableStatsResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "get_table_stats", request).await
|
||||
}
|
||||
|
||||
async fn explain_table_query_plan(
|
||||
&self,
|
||||
request: ExplainTableQueryPlanRequest,
|
||||
) -> lance_core::Result<String> {
|
||||
call_py_method_primitive(
|
||||
self.py_namespace.clone(),
|
||||
"explain_table_query_plan",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn analyze_table_query_plan(
|
||||
&self,
|
||||
request: AnalyzeTableQueryPlanRequest,
|
||||
) -> lance_core::Result<String> {
|
||||
call_py_method_primitive(
|
||||
self.py_namespace.clone(),
|
||||
"analyze_table_query_plan",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn alter_table_add_columns(
|
||||
&self,
|
||||
request: AlterTableAddColumnsRequest,
|
||||
) -> lance_core::Result<AlterTableAddColumnsResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"alter_table_add_columns",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn alter_table_alter_columns(
|
||||
&self,
|
||||
request: AlterTableAlterColumnsRequest,
|
||||
) -> lance_core::Result<AlterTableAlterColumnsResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"alter_table_alter_columns",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn alter_table_drop_columns(
|
||||
&self,
|
||||
request: AlterTableDropColumnsRequest,
|
||||
) -> lance_core::Result<AlterTableDropColumnsResponse> {
|
||||
call_py_method(
|
||||
self.py_namespace.clone(),
|
||||
"alter_table_drop_columns",
|
||||
request,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_table_tags(
|
||||
&self,
|
||||
request: ListTableTagsRequest,
|
||||
) -> lance_core::Result<ListTableTagsResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "list_table_tags", request).await
|
||||
}
|
||||
|
||||
async fn create_table_tag(
|
||||
&self,
|
||||
request: CreateTableTagRequest,
|
||||
) -> lance_core::Result<CreateTableTagResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "create_table_tag", request).await
|
||||
}
|
||||
|
||||
async fn delete_table_tag(
|
||||
&self,
|
||||
request: DeleteTableTagRequest,
|
||||
) -> lance_core::Result<DeleteTableTagResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "delete_table_tag", request).await
|
||||
}
|
||||
|
||||
async fn update_table_tag(
|
||||
&self,
|
||||
request: UpdateTableTagRequest,
|
||||
) -> lance_core::Result<UpdateTableTagResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "update_table_tag", request).await
|
||||
}
|
||||
|
||||
async fn get_table_tag_version(
|
||||
&self,
|
||||
request: GetTableTagVersionRequest,
|
||||
) -> lance_core::Result<GetTableTagVersionResponse> {
|
||||
call_py_method(self.py_namespace.clone(), "get_table_tag_version", request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Python dict to HashMap<String, String>
|
||||
#[allow(dead_code)]
|
||||
fn dict_to_hashmap(dict: &Bound<'_, PyDict>) -> PyResult<HashMap<String, String>> {
|
||||
let mut map = HashMap::new();
|
||||
for (key, value) in dict.iter() {
|
||||
let key_str: String = key.extract()?;
|
||||
let value_str: String = value.extract()?;
|
||||
map.insert(key_str, value_str);
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
/// Extract an Arc<dyn LanceNamespace> from a Python namespace object.
|
||||
///
|
||||
/// This function wraps any Python namespace object with PyLanceNamespace.
|
||||
/// The PyLanceNamespace wrapper uses DictWithModelDump to pass requests,
|
||||
/// which works with both:
|
||||
/// - Native namespaces (DirectoryNamespace, RestNamespace) that use depythonize (expects dict)
|
||||
/// - Custom Python implementations that call .model_dump() on the request
|
||||
pub fn extract_namespace_arc(
|
||||
py: Python<'_>,
|
||||
ns: Py<PyAny>,
|
||||
) -> PyResult<Arc<dyn LanceNamespaceTrait>> {
|
||||
let ns_ref = ns.bind(py);
|
||||
PyLanceNamespace::create_arc(py, ns_ref)
|
||||
}
|
||||
@@ -16,17 +16,32 @@ use lancedb::{
|
||||
query::Select,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
||||
exceptions::PyRuntimeError,
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
|
||||
if table.hasattr("_inner")? {
|
||||
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
|
||||
} else if table.hasattr("_table")? {
|
||||
Ok(table
|
||||
.getattr("_table")?
|
||||
.getattr("_inner")?
|
||||
.downcast_into::<Table>()?)
|
||||
} else {
|
||||
Err(PyRuntimeError::new_err(
|
||||
"Provided table does not appear to be a Table or RemoteTable instance",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[pyo3::pyfunction]
|
||||
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let table = table_from_py(table)?;
|
||||
let inner_table = table.borrow().inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
|
||||
@@ -250,10 +265,8 @@ impl PyPermutationReader {
|
||||
permutation_table: Option<Bound<'py, PyAny>>,
|
||||
split: u64,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let permutation_table = permutation_table
|
||||
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
|
||||
.transpose()?;
|
||||
let base_table = table_from_py(base_table)?;
|
||||
let permutation_table = permutation_table.map(table_from_py).transpose()?;
|
||||
|
||||
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
|
||||
let permutation_table = permutation_table
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::array::make_array;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
@@ -22,23 +22,23 @@ use lancedb::query::{
|
||||
VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use lancedb::table::AnyQuery;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pyfunction;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::Bound;
|
||||
use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::Python;
|
||||
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pyfunction;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
intern,
|
||||
};
|
||||
use pyo3::{pyclass, PyErr};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::{ObjectStoreRegistry, Session as LanceSession};
|
||||
use pyo3::{pyclass, pymethods, PyResult};
|
||||
use pyo3::{PyResult, pyclass, pymethods};
|
||||
|
||||
/// A session for managing caches and object stores across LanceDB operations.
|
||||
///
|
||||
|
||||
@@ -66,13 +66,10 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
|
||||
.inner
|
||||
.bind(py)
|
||||
.call_method0("fetch_storage_options")
|
||||
.map_err(|e| lance_core::Error::IO {
|
||||
source: Box::new(std::io::Error::other(format!(
|
||||
"Failed to call fetch_storage_options: {}",
|
||||
e
|
||||
))),
|
||||
location: snafu::location!(),
|
||||
})?;
|
||||
.map_err(|e| lance_core::Error::io_source(Box::new(std::io::Error::other(format!(
|
||||
"Failed to call fetch_storage_options: {}",
|
||||
e
|
||||
)))))?;
|
||||
|
||||
// If result is None, return None
|
||||
if result.is_none() {
|
||||
@@ -81,26 +78,19 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
|
||||
|
||||
// Extract the result dict - should be a flat Map<String, String>
|
||||
let result_dict = result.downcast::<PyDict>().map_err(|_| {
|
||||
lance_core::Error::InvalidInput {
|
||||
source: "fetch_storage_options() must return None or a dict of string key-value pairs".into(),
|
||||
location: snafu::location!(),
|
||||
}
|
||||
lance_core::Error::invalid_input(
|
||||
"fetch_storage_options() must return a dict of string key-value pairs or None",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Convert all entries to HashMap<String, String>
|
||||
let mut storage_options = HashMap::new();
|
||||
for (key, value) in result_dict.iter() {
|
||||
let key_str: String = key.extract().map_err(|e| {
|
||||
lance_core::Error::InvalidInput {
|
||||
source: format!("Storage option key must be a string: {}", e).into(),
|
||||
location: snafu::location!(),
|
||||
}
|
||||
lance_core::Error::invalid_input(format!("Storage option key must be a string: {}", e))
|
||||
})?;
|
||||
let value_str: String = value.extract().map_err(|e| {
|
||||
lance_core::Error::InvalidInput {
|
||||
source: format!("Storage option value must be a string: {}", e).into(),
|
||||
location: snafu::location!(),
|
||||
}
|
||||
lance_core::Error::invalid_input(format!("Storage option value must be a string: {}", e))
|
||||
})?;
|
||||
storage_options.insert(key_str, value_str);
|
||||
}
|
||||
@@ -109,13 +99,10 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| lance_core::Error::IO {
|
||||
source: Box::new(std::io::Error::other(format!(
|
||||
"Task join error: {}",
|
||||
e
|
||||
))),
|
||||
location: snafu::location!(),
|
||||
})?
|
||||
.map_err(|e| lance_core::Error::io_source(Box::new(std::io::Error::other(format!(
|
||||
"Task join error: {}",
|
||||
e
|
||||
)))))?
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> String {
|
||||
|
||||
@@ -5,8 +5,9 @@ use std::{collections::HashMap, sync::Arc};
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
error::PythonErrorExt,
|
||||
index::{extract_index_params, IndexConfig},
|
||||
index::{IndexConfig, extract_index_params},
|
||||
query::{Query, TakeQuery},
|
||||
table::scannable::PyScannable,
|
||||
};
|
||||
use arrow::{
|
||||
datatypes::{DataType, Schema},
|
||||
@@ -18,13 +19,15 @@ use lancedb::table::{
|
||||
Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
mod scannable;
|
||||
|
||||
/// Statistics about a compaction operation.
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -109,19 +112,24 @@ impl From<lancedb::table::AddResult> for AddResult {
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DeleteResult {
|
||||
pub num_deleted_rows: u64,
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl DeleteResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("DeleteResult(version={})", self.version)
|
||||
format!(
|
||||
"DeleteResult(num_deleted_rows={}, version={})",
|
||||
self.num_deleted_rows, self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::DeleteResult> for DeleteResult {
|
||||
fn from(result: lancedb::table::DeleteResult) -> Self {
|
||||
Self {
|
||||
num_deleted_rows: result.num_deleted_rows,
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
@@ -293,12 +301,10 @@ impl Table {
|
||||
|
||||
pub fn add<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
data: Bound<'_, PyAny>,
|
||||
data: PyScannable,
|
||||
mode: String,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let batches: Box<dyn arrow::array::RecordBatchReader + Send> =
|
||||
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?);
|
||||
let mut op = self_.inner_ref()?.add(batches);
|
||||
let mut op = self_.inner_ref()?.add(data);
|
||||
if mode == "append" {
|
||||
op = op.mode(AddDataMode::Append);
|
||||
} else if mode == "overwrite" {
|
||||
@@ -536,7 +542,7 @@ impl Table {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let versions = inner.list_versions().await.infer_error()?;
|
||||
let versions_as_dict = Python::attach(|py| {
|
||||
Python::attach(|py| {
|
||||
versions
|
||||
.iter()
|
||||
.map(|v| {
|
||||
@@ -553,9 +559,7 @@ impl Table {
|
||||
Ok(dict.unbind())
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()
|
||||
});
|
||||
|
||||
versions_as_dict
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
145
python/src/table/scannable.rs
Normal file
145
python/src/table/scannable.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::{
|
||||
datatypes::{Schema, SchemaRef},
|
||||
ffi_stream::ArrowArrayStreamReader,
|
||||
pyarrow::{FromPyArrow, PyArrowType},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use lancedb::{
|
||||
Error,
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
data::scannable::Scannable,
|
||||
};
|
||||
use pyo3::{FromPyObject, Py, PyAny, Python, types::PyAnyMethods};
|
||||
|
||||
/// Adapter that implements Scannable for a Python reader factory callable.
|
||||
///
|
||||
/// This holds a Python callable that returns a RecordBatchReader when called.
|
||||
/// For rescannable sources, the callable can be invoked multiple times to
|
||||
/// get fresh readers.
|
||||
pub struct PyScannable {
|
||||
/// Python callable that returns a RecordBatchReader
|
||||
reader_factory: Py<PyAny>,
|
||||
schema: SchemaRef,
|
||||
num_rows: Option<usize>,
|
||||
rescannable: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PyScannable {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PyScannable")
|
||||
.field("schema", &self.schema)
|
||||
.field("num_rows", &self.num_rows)
|
||||
.field("rescannable", &self.rescannable)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Scannable for PyScannable {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
self.schema.clone()
|
||||
}
|
||||
|
||||
fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
|
||||
let reader: Result<ArrowArrayStreamReader, Error> = {
|
||||
Python::attach(|py| {
|
||||
let result =
|
||||
self.reader_factory
|
||||
.call0(py)
|
||||
.map_err(|e| lancedb::Error::Runtime {
|
||||
message: format!("Python reader factory failed: {}", e),
|
||||
})?;
|
||||
ArrowArrayStreamReader::from_pyarrow_bound(result.bind(py)).map_err(|e| {
|
||||
lancedb::Error::Runtime {
|
||||
message: format!("Failed to create Arrow reader from Python: {}", e),
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
// Reader is blocking but stream is non-blocking, so we need to spawn a task to pull.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
let join_handle = tokio::task::spawn_blocking(move || {
|
||||
let reader = match reader {
|
||||
Ok(reader) => reader,
|
||||
Err(e) => {
|
||||
let _ = tx.blocking_send(Err(e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
for batch in reader {
|
||||
match batch {
|
||||
Ok(batch) => {
|
||||
if tx.blocking_send(Ok(batch)).is_err() {
|
||||
// Receiver dropped, stop processing
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(source) => {
|
||||
let _ = tx.blocking_send(Err(Error::Arrow { source }));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let schema = self.schema.clone();
|
||||
let stream = futures::stream::unfold(
|
||||
(rx, Some(join_handle)),
|
||||
|(mut rx, join_handle)| async move {
|
||||
match rx.recv().await {
|
||||
Some(Ok(batch)) => Some((Ok(batch), (rx, join_handle))),
|
||||
Some(Err(e)) => Some((Err(e), (rx, join_handle))),
|
||||
None => {
|
||||
// Channel closed. Check if the task panicked — a panic
|
||||
// drops the sender without sending an error, so without
|
||||
// this check we'd silently return a truncated stream.
|
||||
if let Some(handle) = join_handle
|
||||
&& let Err(join_err) = handle.await
|
||||
{
|
||||
return Some((
|
||||
Err(Error::Runtime {
|
||||
message: format!("Reader task panicked: {}", join_err),
|
||||
}),
|
||||
(rx, None),
|
||||
));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
Box::pin(SimpleRecordBatchStream::new(stream.fuse(), schema))
|
||||
}
|
||||
|
||||
fn num_rows(&self) -> Option<usize> {
|
||||
self.num_rows
|
||||
}
|
||||
|
||||
fn rescannable(&self) -> bool {
|
||||
self.rescannable
|
||||
}
|
||||
}
|
||||
|
||||
impl<'py> FromPyObject<'py> for PyScannable {
|
||||
fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
|
||||
// Convert from Scannable dataclass.
|
||||
let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
|
||||
let schema = Arc::new(schema.0);
|
||||
let num_rows: Option<usize> = ob.getattr("num_rows")?.extract()?;
|
||||
let rescannable: bool = ob.getattr("rescannable")?.extract()?;
|
||||
let reader_factory: Py<PyAny> = ob.getattr("reader")?.unbind();
|
||||
|
||||
Ok(Self {
|
||||
schema,
|
||||
reader_factory,
|
||||
num_rows,
|
||||
rescannable,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,9 @@ use std::sync::Mutex;
|
||||
|
||||
use lancedb::DistanceType;
|
||||
use pyo3::{
|
||||
PyResult,
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyfunction, PyResult,
|
||||
pyfunction,
|
||||
};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
|
||||
4
python/uv.lock
generated
4
python/uv.lock
generated
@@ -2006,7 +2006,7 @@ requires-dist = [
|
||||
{ name = "botocore", marker = "extra == 'embeddings'", specifier = ">=1.31.57" },
|
||||
{ name = "cohere", marker = "extra == 'embeddings'" },
|
||||
{ name = "colpali-engine", marker = "extra == 'embeddings'", specifier = ">=0.3.10" },
|
||||
{ name = "datafusion", marker = "extra == 'tests'" },
|
||||
{ name = "datafusion", marker = "extra == 'tests'", specifier = "<52" },
|
||||
{ name = "deprecation" },
|
||||
{ name = "duckdb", marker = "extra == 'tests'" },
|
||||
{ name = "google-generativeai", marker = "extra == 'embeddings'" },
|
||||
@@ -2035,7 +2035,7 @@ requires-dist = [
|
||||
{ name = "pyarrow-stubs", marker = "extra == 'tests'" },
|
||||
{ name = "pydantic", specifier = ">=1.10" },
|
||||
{ name = "pylance", marker = "extra == 'pylance'", specifier = ">=1.0.0b14" },
|
||||
{ name = "pylance", marker = "extra == 'tests'", specifier = ">=1.0.0b14" },
|
||||
{ name = "pylance", marker = "extra == 'tests'", specifier = ">=1.0.0b14,<3.0.0" },
|
||||
{ name = "pyright", marker = "extra == 'dev'" },
|
||||
{ name = "pytest", marker = "extra == 'tests'" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'tests'" },
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "1.90.0"
|
||||
channel = "1.91.0"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.26.2"
|
||||
version = "0.27.0-beta.4"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -25,7 +25,9 @@ datafusion-catalog.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-execution.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion-functions.workspace = true
|
||||
datafusion-physical-expr.workspace = true
|
||||
datafusion-sql.workspace = true
|
||||
datafusion-physical-plan.workspace = true
|
||||
datafusion.workspace = true
|
||||
object_store = { workspace = true }
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# LanceDB Rust
|
||||
# LanceDB Rust SDK
|
||||
|
||||
<a href="https://crates.io/crates/vectordb"></a>
|
||||
<a href="https://docs.rs/vectordb/latest/vectordb/"></a>
|
||||
|
||||
@@ -9,10 +9,9 @@ use aws_config::Region;
|
||||
use aws_sdk_bedrockruntime::Client;
|
||||
use futures::StreamExt;
|
||||
use lancedb::{
|
||||
connect,
|
||||
embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
|
||||
Result, connect,
|
||||
embeddings::{EmbeddingDefinition, EmbeddingFunction, bedrock::BedrockEmbeddingFunction},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Result,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -10,10 +10,10 @@ use futures::TryStreamExt;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use lancedb::connection::Connection;
|
||||
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::index::Index;
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::query::{ExecutableQuery, QueryBase};
|
||||
use lancedb::{connect, Result, Table};
|
||||
use lancedb::{Result, Table, connect};
|
||||
use rand::random;
|
||||
|
||||
#[tokio::main]
|
||||
@@ -46,19 +46,21 @@ fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send
|
||||
.collect::<Vec<_>>();
|
||||
let n_terms = 3;
|
||||
let batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||
Arc::new(StringArray::from_iter_values((0..TOTAL).map(|_| {
|
||||
(0..n_terms)
|
||||
.map(|_| words[random::<u32>() as usize % words.len()])
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}))),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
vec![
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||
Arc::new(StringArray::from_iter_values((0..TOTAL).map(|_| {
|
||||
(0..n_terms)
|
||||
.map(|_| words[random::<u32>() as usize % words.len()])
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}))),
|
||||
],
|
||||
)
|
||||
.unwrap(),
|
||||
]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
|
||||
@@ -5,16 +5,15 @@ use arrow_array::{RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::TryStreamExt;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::index::Index;
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::{
|
||||
connect,
|
||||
Result, Table, connect,
|
||||
embeddings::{
|
||||
sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
|
||||
EmbeddingFunction,
|
||||
EmbeddingDefinition, EmbeddingFunction,
|
||||
sentence_transformers::SentenceTransformersEmbeddings,
|
||||
},
|
||||
query::{QueryBase, QueryExecutionOptions},
|
||||
Result, Table,
|
||||
};
|
||||
use std::{iter::once, sync::Arc};
|
||||
|
||||
|
||||
@@ -14,10 +14,10 @@ use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::TryStreamExt;
|
||||
use lancedb::connection::Connection;
|
||||
|
||||
use lancedb::index::vector::IvfPqIndexBuilder;
|
||||
use lancedb::index::Index;
|
||||
use lancedb::index::vector::IvfPqIndexBuilder;
|
||||
use lancedb::query::{ExecutableQuery, QueryBase};
|
||||
use lancedb::{connect, DistanceType, Result, Table};
|
||||
use lancedb::{DistanceType, Result, Table, connect};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -51,19 +51,21 @@ fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send
|
||||
|
||||
// Create a RecordBatch stream.
|
||||
let batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||
Arc::new(
|
||||
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
|
||||
DIM as i32,
|
||||
vec![
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||
Arc::new(
|
||||
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
|
||||
DIM as i32,
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
],
|
||||
)
|
||||
.unwrap(),
|
||||
]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
|
||||
@@ -8,10 +8,9 @@ use std::{iter::once, sync::Arc};
|
||||
use arrow_array::{RecordBatch, StringArray};
|
||||
use futures::StreamExt;
|
||||
use lancedb::{
|
||||
connect,
|
||||
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
|
||||
Result, connect,
|
||||
embeddings::{EmbeddingDefinition, EmbeddingFunction, openai::OpenAIEmbeddingFunction},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Result,
|
||||
};
|
||||
|
||||
// --8<-- [end:imports]
|
||||
|
||||
@@ -7,13 +7,12 @@ use arrow_array::{RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::StreamExt;
|
||||
use lancedb::{
|
||||
connect,
|
||||
Result, connect,
|
||||
embeddings::{
|
||||
sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
|
||||
EmbeddingFunction,
|
||||
EmbeddingDefinition, EmbeddingFunction,
|
||||
sentence_transformers::SentenceTransformersEmbeddings,
|
||||
},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Result,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -14,7 +14,7 @@ use futures::TryStreamExt;
|
||||
use lancedb::connection::Connection;
|
||||
use lancedb::index::Index;
|
||||
use lancedb::query::{ExecutableQuery, QueryBase};
|
||||
use lancedb::{connect, Result, Table as LanceDbTable};
|
||||
use lancedb::{Result, Table as LanceDbTable, connect};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
|
||||
@@ -12,7 +12,7 @@ use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
|
||||
#[cfg(feature = "polars")]
|
||||
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
|
||||
|
||||
use crate::{error::Result, Error};
|
||||
use crate::{Error, error::Result};
|
||||
|
||||
/// An iterator of batches that also has a schema
|
||||
pub trait RecordBatchReader: Iterator<Item = Result<arrow_array::RecordBatch>> {
|
||||
@@ -155,9 +155,7 @@ impl IntoArrowStream for SendableRecordBatchStream {
|
||||
impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
|
||||
fn into_arrow(self) -> Result<SendableRecordBatchStream> {
|
||||
let schema = self.schema();
|
||||
let stream = self.map_err(|df_err| Error::Runtime {
|
||||
message: df_err.to_string(),
|
||||
});
|
||||
let stream = self.map_err(|df_err| df_err.into());
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ use lance_namespace::models::{
|
||||
#[cfg(feature = "aws")]
|
||||
use object_store::aws::AwsCredential;
|
||||
|
||||
use crate::Table;
|
||||
use crate::connection::create_table::CreateTableBuilder;
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::database::listing::ListingDatabase;
|
||||
@@ -31,7 +32,6 @@ use crate::remote::{
|
||||
client::ClientConfig,
|
||||
db::{OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION},
|
||||
};
|
||||
use crate::Table;
|
||||
use lance::io::ObjectStoreParams;
|
||||
pub use lance_encoding::version::LanceFileVersion;
|
||||
#[cfg(feature = "remote")]
|
||||
@@ -136,6 +136,7 @@ impl OpenTableBuilder {
|
||||
lance_read_params: None,
|
||||
location: None,
|
||||
namespace_client: None,
|
||||
managed_versioning: None,
|
||||
},
|
||||
embedding_registry,
|
||||
}
|
||||
@@ -235,6 +236,29 @@ impl OpenTableBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a namespace client for managed versioning support.
|
||||
///
|
||||
/// When a namespace client is provided and the table has `managed_versioning` enabled,
|
||||
/// the table will use the namespace's commit handler to notify the namespace of
|
||||
/// version changes. This enables features like event emission for table modifications.
|
||||
pub fn namespace_client(mut self, client: Arc<dyn lance_namespace::LanceNamespace>) -> Self {
|
||||
self.request.namespace_client = Some(client);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether managed versioning is enabled for this table.
|
||||
///
|
||||
/// When set to `Some(true)`, the table will use namespace-managed commits.
|
||||
/// When set to `Some(false)`, the table will use local commits even if namespace_client is set.
|
||||
/// When set to `None` (default), the value will be fetched from the namespace if namespace_client is set.
|
||||
///
|
||||
/// This is typically set when the caller has already queried the namespace and knows the
|
||||
/// managed_versioning status, avoiding a redundant describe_table call.
|
||||
pub fn managed_versioning(mut self, enabled: bool) -> Self {
|
||||
self.request.managed_versioning = Some(enabled);
|
||||
self
|
||||
}
|
||||
|
||||
/// Open the table
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
let table = self.parent.open_table(self.request).await?;
|
||||
@@ -294,6 +318,12 @@ impl CloneTableBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a namespace client for managed versioning support.
|
||||
pub fn namespace_client(mut self, client: Arc<dyn lance_namespace::LanceNamespace>) -> Self {
|
||||
self.request.namespace_client = Some(client);
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute the clone operation
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
let parent = self.parent.clone();
|
||||
@@ -566,8 +596,11 @@ pub struct ConnectBuilder {
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
|
||||
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
|
||||
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 3] = [
|
||||
("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name"),
|
||||
("AZURE_CLIENT_ID", "azure_client_id"),
|
||||
("AZURE_TENANT_ID", "azure_tenant_id"),
|
||||
];
|
||||
|
||||
impl ConnectBuilder {
|
||||
/// Create a new [`ConnectOptions`] with the given database URI.
|
||||
@@ -758,10 +791,10 @@ impl ConnectBuilder {
|
||||
options: &mut HashMap<String, String>,
|
||||
) {
|
||||
for (env_key, opt_key) in env_var_to_remote_storage_option {
|
||||
if let Ok(env_value) = std::env::var(env_key) {
|
||||
if !options.contains_key(*opt_key) {
|
||||
options.insert((*opt_key).to_string(), env_value);
|
||||
}
|
||||
if let Ok(env_value) = std::env::var(env_key)
|
||||
&& !options.contains_key(*opt_key)
|
||||
{
|
||||
options.insert((*opt_key).to_string(), env_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1011,14 +1044,13 @@ mod tests {
|
||||
#[cfg(feature = "remote")]
|
||||
#[test]
|
||||
fn test_apply_env_defaults() {
|
||||
let env_key = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_KEY";
|
||||
let env_val = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_VAL";
|
||||
let env_key = "PATH";
|
||||
let env_val = std::env::var(env_key).expect("PATH should be set in test environment");
|
||||
let opts_key = "test_apply_env_defaults_environment_variable_opts_key";
|
||||
std::env::set_var(env_key, env_val);
|
||||
|
||||
let mut options = HashMap::new();
|
||||
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
|
||||
assert_eq!(Some(&env_val.to_string()), options.get(opts_key));
|
||||
assert_eq!(Some(&env_val), options.get(opts_key));
|
||||
|
||||
options.insert(opts_key.to_string(), "EXPLICIT-VALUE".to_string());
|
||||
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
|
||||
|
||||
@@ -6,12 +6,12 @@ use std::sync::Arc;
|
||||
use lance_io::object_store::StorageOptionsProvider;
|
||||
|
||||
use crate::{
|
||||
Error, Result, Table,
|
||||
connection::{merge_storage_options, set_storage_options_provider},
|
||||
data::scannable::{Scannable, WithEmbeddingsScannable},
|
||||
database::{CreateTableMode, CreateTableRequest, Database},
|
||||
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
|
||||
table::WriteOptions,
|
||||
Error, Result, Table,
|
||||
};
|
||||
|
||||
pub struct CreateTableBuilder {
|
||||
@@ -167,7 +167,7 @@ impl CreateTableBuilder {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow_array::{
|
||||
record_batch, Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator,
|
||||
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, record_batch,
|
||||
};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema};
|
||||
use futures::TryStreamExt;
|
||||
@@ -380,11 +380,12 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
|
||||
assert!(db
|
||||
.create_empty_table("test", other_schema.clone())
|
||||
.execute()
|
||||
.await
|
||||
.is_err()); // TODO: assert what this error is
|
||||
assert!(
|
||||
db.create_empty_table("test", other_schema.clone())
|
||||
.execute()
|
||||
.await
|
||||
.is_err()
|
||||
); // TODO: assert what this error is
|
||||
let overwritten = db
|
||||
.create_empty_table("test", other_schema.clone())
|
||||
.mode(CreateTableMode::Overwrite)
|
||||
|
||||
@@ -5,9 +5,9 @@ use std::collections::HashMap;
|
||||
|
||||
use arrow::compute::kernels::{aggregate::bool_and, length::length};
|
||||
use arrow_array::{
|
||||
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
|
||||
cast::AsArray,
|
||||
types::{ArrowPrimitiveType, Int32Type, Int64Type},
|
||||
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
|
||||
};
|
||||
use arrow_ord::cmp::eq;
|
||||
use arrow_schema::DataType;
|
||||
@@ -78,7 +78,7 @@ pub fn infer_vector_columns(
|
||||
_ => {
|
||||
return Err(Error::Schema {
|
||||
message: format!("Column {} is not a list", col_name),
|
||||
})
|
||||
});
|
||||
}
|
||||
} {
|
||||
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
|
||||
@@ -102,8 +102,8 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_array::{
|
||||
types::{Float32Type, Float64Type},
|
||||
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
|
||||
types::{Float32Type, Float64Type},
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::{sync::Arc, vec};
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
use std::{iter::repeat_with, sync::Arc};
|
||||
|
||||
use arrow_array::{
|
||||
cast::AsArray,
|
||||
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
|
||||
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
cast::AsArray,
|
||||
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
|
||||
};
|
||||
use arrow_cast::{can_cast_types, cast};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema};
|
||||
@@ -184,7 +184,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
|
||||
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int8Array, Int32Array,
|
||||
RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::Field;
|
||||
|
||||
@@ -9,22 +9,21 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{ArrowError, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::once;
|
||||
use futures::StreamExt;
|
||||
use lance_datafusion::utils::StreamingWriteSource;
|
||||
|
||||
use crate::arrow::{
|
||||
SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream,
|
||||
};
|
||||
use crate::embeddings::{
|
||||
compute_embeddings_for_batch, compute_output_schema, EmbeddingDefinition, EmbeddingFunction,
|
||||
EmbeddingRegistry,
|
||||
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, compute_embeddings_for_batch,
|
||||
compute_output_schema,
|
||||
};
|
||||
use crate::table::{ColumnDefinition, ColumnKind, TableDefinition};
|
||||
use crate::{Error, Result};
|
||||
use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{ArrowError, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::once;
|
||||
use lance_datafusion::utils::StreamingWriteSource;
|
||||
|
||||
pub trait Scannable: Send {
|
||||
/// Returns the schema of the data.
|
||||
@@ -228,6 +227,19 @@ impl WithEmbeddingsScannable {
|
||||
let table_definition = TableDefinition::new(output_schema, column_definitions);
|
||||
let output_schema = table_definition.into_rich_schema();
|
||||
|
||||
Self::with_schema(inner, embeddings, output_schema)
|
||||
}
|
||||
|
||||
/// Create a WithEmbeddingsScannable with a specific output schema.
|
||||
///
|
||||
/// Use this when the table schema is already known (e.g. during add) to
|
||||
/// avoid nullability mismatches between the embedding function's declared
|
||||
/// type and the table's stored type.
|
||||
pub fn with_schema(
|
||||
inner: Box<dyn Scannable>,
|
||||
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
|
||||
output_schema: SchemaRef,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
inner,
|
||||
embeddings,
|
||||
@@ -245,9 +257,11 @@ impl Scannable for WithEmbeddingsScannable {
|
||||
let inner_stream = self.inner.scan_as_stream();
|
||||
let embeddings = self.embeddings.clone();
|
||||
let output_schema = self.output_schema.clone();
|
||||
let stream_schema = output_schema.clone();
|
||||
|
||||
let mapped_stream = inner_stream.then(move |batch_result| {
|
||||
let embeddings = embeddings.clone();
|
||||
let output_schema = output_schema.clone();
|
||||
async move {
|
||||
let batch = batch_result?;
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
@@ -257,12 +271,29 @@ impl Scannable for WithEmbeddingsScannable {
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Task panicked during embedding computation: {}", e),
|
||||
})??;
|
||||
// Cast columns to match the declared output schema. The data is
|
||||
// identical but field metadata (e.g. nested nullability) may
|
||||
// differ between the embedding function output and the table.
|
||||
let columns: Vec<ArrayRef> = result
|
||||
.columns()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, col)| {
|
||||
let target_type = output_schema.field(i).data_type();
|
||||
if col.data_type() == target_type {
|
||||
Ok(col.clone())
|
||||
} else {
|
||||
arrow_cast::cast(col, target_type).map_err(Error::from)
|
||||
}
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let result = RecordBatch::try_new(output_schema, columns)?;
|
||||
Ok(result)
|
||||
}
|
||||
});
|
||||
|
||||
Box::pin(SimpleRecordBatchStream {
|
||||
schema: output_schema,
|
||||
schema: stream_schema,
|
||||
stream: mapped_stream,
|
||||
})
|
||||
}
|
||||
@@ -303,8 +334,13 @@ pub fn scannable_with_embeddings(
|
||||
}
|
||||
|
||||
if !embeddings.is_empty() {
|
||||
return Ok(Box::new(WithEmbeddingsScannable::try_new(
|
||||
inner, embeddings,
|
||||
// Use the table's schema so embedding column types (including nested
|
||||
// nullability) match what's stored, avoiding mismatches with the
|
||||
// embedding function's declared dest_type.
|
||||
return Ok(Box::new(WithEmbeddingsScannable::with_schema(
|
||||
inner,
|
||||
embeddings,
|
||||
table_definition.schema.clone(),
|
||||
)?));
|
||||
}
|
||||
}
|
||||
@@ -312,6 +348,133 @@ pub fn scannable_with_embeddings(
|
||||
Ok(inner)
|
||||
}
|
||||
|
||||
/// A wrapper that buffers the first RecordBatch from a Scannable so we can
|
||||
/// inspect it (e.g. to estimate data size) without losing it.
|
||||
pub(crate) struct PeekedScannable {
|
||||
inner: Box<dyn Scannable>,
|
||||
peeked: Option<RecordBatch>,
|
||||
/// The first item from the stream, if it was an error. Stored so we can
|
||||
/// re-emit it from `scan_as_stream` instead of silently dropping it.
|
||||
first_error: Option<crate::Error>,
|
||||
stream: Option<SendableRecordBatchStream>,
|
||||
}
|
||||
|
||||
impl PeekedScannable {
|
||||
pub fn new(inner: Box<dyn Scannable>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
peeked: None,
|
||||
first_error: None,
|
||||
stream: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads and buffers the first batch from the inner scannable.
|
||||
/// Returns a clone of it. Subsequent calls return the same batch.
|
||||
///
|
||||
/// Returns `None` if the stream is empty or the first item is an error.
|
||||
/// Errors are preserved and re-emitted by `scan_as_stream`.
|
||||
pub async fn peek(&mut self) -> Option<RecordBatch> {
|
||||
if self.peeked.is_some() {
|
||||
return self.peeked.clone();
|
||||
}
|
||||
// Already peeked and got an error or empty stream.
|
||||
if self.stream.is_some() || self.first_error.is_some() {
|
||||
return None;
|
||||
}
|
||||
let mut stream = self.inner.scan_as_stream();
|
||||
match stream.next().await {
|
||||
Some(Ok(batch)) => {
|
||||
self.peeked = Some(batch.clone());
|
||||
self.stream = Some(stream);
|
||||
Some(batch)
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
self.first_error = Some(e);
|
||||
self.stream = Some(stream);
|
||||
None
|
||||
}
|
||||
None => {
|
||||
self.stream = Some(stream);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Scannable for PeekedScannable {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
self.inner.schema()
|
||||
}
|
||||
|
||||
fn num_rows(&self) -> Option<usize> {
|
||||
self.inner.num_rows()
|
||||
}
|
||||
|
||||
fn rescannable(&self) -> bool {
|
||||
self.inner.rescannable()
|
||||
}
|
||||
|
||||
fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
|
||||
let schema = self.inner.schema();
|
||||
|
||||
// If peek() hit an error, prepend it so downstream sees the error.
|
||||
let error_item = self.first_error.take().map(Err);
|
||||
|
||||
match (self.peeked.take(), self.stream.take()) {
|
||||
(Some(batch), Some(rest)) => {
|
||||
let prepend = futures::stream::once(std::future::ready(Ok(batch)));
|
||||
Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: prepend.chain(rest),
|
||||
})
|
||||
}
|
||||
(Some(batch), None) => Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: futures::stream::once(std::future::ready(Ok(batch))),
|
||||
}),
|
||||
(None, Some(rest)) => {
|
||||
if let Some(err) = error_item {
|
||||
let stream = futures::stream::once(std::future::ready(err));
|
||||
Box::pin(SimpleRecordBatchStream { schema, stream })
|
||||
} else {
|
||||
rest
|
||||
}
|
||||
}
|
||||
(None, None) => {
|
||||
// peek() was never called — just delegate
|
||||
self.inner.scan_as_stream()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the number of write partitions based on data size estimates.
|
||||
///
|
||||
/// `sample_bytes` and `sample_rows` come from a representative batch and are
|
||||
/// used to estimate per-row size. `total_rows_hint` is the total row count
|
||||
/// when known; otherwise `sample_rows` row count is used as a lower bound
|
||||
/// estimate.
|
||||
///
|
||||
/// Targets roughly 1 million rows or 2 GB per partition, capped at
|
||||
/// `max_partitions` (typically the number of available CPU cores).
|
||||
pub(crate) fn estimate_write_partitions(
|
||||
sample_bytes: usize,
|
||||
sample_rows: usize,
|
||||
total_rows_hint: Option<usize>,
|
||||
max_partitions: usize,
|
||||
) -> usize {
|
||||
if sample_rows == 0 {
|
||||
return 1;
|
||||
}
|
||||
let bytes_per_row = sample_bytes / sample_rows;
|
||||
let total_rows = total_rows_hint.unwrap_or(sample_rows);
|
||||
let total_bytes = total_rows * bytes_per_row;
|
||||
let by_rows = total_rows.div_ceil(1_000_000);
|
||||
let by_bytes = total_bytes.div_ceil(2 * 1024 * 1024 * 1024);
|
||||
by_rows.max(by_bytes).max(1).min(max_partitions)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -408,6 +571,231 @@ mod tests {
|
||||
assert!(result2.unwrap().is_err());
|
||||
}
|
||||
|
||||
mod peeked_scannable_tests {
|
||||
use crate::test_utils::TestCustomError;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_peek_returns_first_batch() {
|
||||
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
|
||||
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
|
||||
|
||||
let first = peeked.peek().await.unwrap();
|
||||
assert_eq!(first, batch);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_peek_is_idempotent() {
|
||||
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
|
||||
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
|
||||
|
||||
let first = peeked.peek().await.unwrap();
|
||||
let second = peeked.peek().await.unwrap();
|
||||
assert_eq!(first, second);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_after_peek_returns_all_data() {
|
||||
let batches = vec![
|
||||
record_batch!(("id", Int64, [1, 2])).unwrap(),
|
||||
record_batch!(("id", Int64, [3, 4, 5])).unwrap(),
|
||||
];
|
||||
let mut peeked = PeekedScannable::new(Box::new(batches.clone()));
|
||||
|
||||
let first = peeked.peek().await.unwrap();
|
||||
assert_eq!(first, batches[0]);
|
||||
|
||||
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0], batches[0]);
|
||||
assert_eq!(result[1], batches[1]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scan_without_peek_passes_through() {
|
||||
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
|
||||
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
|
||||
|
||||
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0], batch);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delegates_num_rows() {
|
||||
let batches = vec![
|
||||
record_batch!(("id", Int64, [1, 2])).unwrap(),
|
||||
record_batch!(("id", Int64, [3])).unwrap(),
|
||||
];
|
||||
let peeked = PeekedScannable::new(Box::new(batches));
|
||||
assert_eq!(peeked.num_rows(), Some(3));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_rescannable_stream_data_preserved() {
|
||||
let batches = vec![
|
||||
record_batch!(("id", Int64, [1, 2])).unwrap(),
|
||||
record_batch!(("id", Int64, [3])).unwrap(),
|
||||
];
|
||||
let schema = batches[0].schema();
|
||||
let inner = futures::stream::iter(batches.clone().into_iter().map(Ok));
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: inner,
|
||||
});
|
||||
|
||||
let mut peeked = PeekedScannable::new(Box::new(stream));
|
||||
assert!(!peeked.rescannable());
|
||||
assert_eq!(peeked.num_rows(), None);
|
||||
|
||||
let first = peeked.peek().await.unwrap();
|
||||
assert_eq!(first, batches[0]);
|
||||
|
||||
// All data is still available via scan_as_stream
|
||||
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0], batches[0]);
|
||||
assert_eq!(result[1], batches[1]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_in_first_batch_propagates() {
|
||||
let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
|
||||
"id",
|
||||
arrow_schema::DataType::Int64,
|
||||
false,
|
||||
)]));
|
||||
let inner = futures::stream::iter(vec![Err(Error::External {
|
||||
source: Box::new(TestCustomError),
|
||||
})]);
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: inner,
|
||||
});
|
||||
|
||||
let mut peeked = PeekedScannable::new(Box::new(stream));
|
||||
|
||||
// peek returns None for errors
|
||||
assert!(peeked.peek().await.is_none());
|
||||
|
||||
// But the error should come through when scanning
|
||||
let mut stream = peeked.scan_as_stream();
|
||||
let first = stream.next().await.unwrap();
|
||||
assert!(first.is_err());
|
||||
let err = first.unwrap_err();
|
||||
assert!(
|
||||
matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
|
||||
"Expected TestCustomError to be preserved, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_in_later_batch_propagates() {
|
||||
let good_batch = record_batch!(("id", Int64, [1, 2])).unwrap();
|
||||
let schema = good_batch.schema();
|
||||
let inner = futures::stream::iter(vec![
|
||||
Ok(good_batch.clone()),
|
||||
Err(Error::External {
|
||||
source: Box::new(TestCustomError),
|
||||
}),
|
||||
]);
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: inner,
|
||||
});
|
||||
|
||||
let mut peeked = PeekedScannable::new(Box::new(stream));
|
||||
|
||||
// peek succeeds with the first batch
|
||||
let first = peeked.peek().await.unwrap();
|
||||
assert_eq!(first, good_batch);
|
||||
|
||||
// scan_as_stream should yield the first batch, then the error
|
||||
let mut stream = peeked.scan_as_stream();
|
||||
let batch1 = stream.next().await.unwrap().unwrap();
|
||||
assert_eq!(batch1, good_batch);
|
||||
|
||||
let batch2 = stream.next().await.unwrap();
|
||||
assert!(batch2.is_err());
|
||||
let err = batch2.unwrap_err();
|
||||
assert!(
|
||||
matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
|
||||
"Expected TestCustomError to be preserved, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_stream_returns_none() {
|
||||
let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
|
||||
"id",
|
||||
arrow_schema::DataType::Int64,
|
||||
false,
|
||||
)]));
|
||||
let inner = futures::stream::empty();
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
|
||||
schema,
|
||||
stream: inner,
|
||||
});
|
||||
|
||||
let mut peeked = PeekedScannable::new(Box::new(stream));
|
||||
assert!(peeked.peek().await.is_none());
|
||||
|
||||
// Scanning an empty (post-peek) stream should yield nothing
|
||||
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
mod estimate_write_partitions_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_small_data_single_partition() {
|
||||
// 100 rows * 24 bytes/row = 2400 bytes — well under both thresholds
|
||||
assert_eq!(estimate_write_partitions(2400, 100, Some(100), 8), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scales_by_row_count() {
|
||||
// 2.5M rows at 24 bytes/row — row threshold dominates
|
||||
// ceil(2_500_000 / 1_000_000) = 3
|
||||
assert_eq!(estimate_write_partitions(72, 3, Some(2_500_000), 8), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scales_by_byte_size() {
|
||||
// 100k rows at 40KB/row = ~4GB total → ceil(4GB / 2GB) = 2
|
||||
let sample_bytes = 40_000 * 10;
|
||||
assert_eq!(
|
||||
estimate_write_partitions(sample_bytes, 10, Some(100_000), 8),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capped_at_max_partitions() {
|
||||
// 10M rows would want 10 partitions, but capped at 4
|
||||
assert_eq!(estimate_write_partitions(72, 3, Some(10_000_000), 4), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_sample_rows_returns_one() {
|
||||
assert_eq!(estimate_write_partitions(0, 0, Some(1_000_000), 8), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_row_hint_uses_sample_size() {
|
||||
// Without a hint, uses sample_rows (3), which is small
|
||||
assert_eq!(estimate_write_partitions(72, 3, None, 8), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_always_at_least_one() {
|
||||
assert_eq!(estimate_write_partitions(24, 1, Some(1), 8), 1);
|
||||
}
|
||||
}
|
||||
|
||||
mod embedding_tests {
|
||||
use super::*;
|
||||
use crate::embeddings::MemoryRegistry;
|
||||
|
||||
@@ -19,12 +19,12 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use lance::dataset::ReadParams;
|
||||
use lance_namespace::LanceNamespace;
|
||||
use lance_namespace::models::{
|
||||
CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest,
|
||||
DescribeNamespaceResponse, DropNamespaceRequest, DropNamespaceResponse, ListNamespacesRequest,
|
||||
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
|
||||
};
|
||||
use lance_namespace::LanceNamespace;
|
||||
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::error::Result;
|
||||
@@ -66,6 +66,10 @@ pub struct OpenTableRequest {
|
||||
/// Optional namespace client for server-side query execution.
|
||||
/// When set, queries will be executed on the namespace server instead of locally.
|
||||
pub namespace_client: Option<Arc<dyn LanceNamespace>>,
|
||||
/// Whether managed versioning is enabled for this table.
|
||||
/// When Some(true), the table will use namespace-managed commits instead of local commits.
|
||||
/// When None and namespace_client is provided, the value will be fetched from the namespace.
|
||||
pub managed_versioning: Option<bool>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OpenTableRequest {
|
||||
@@ -77,6 +81,7 @@ impl std::fmt::Debug for OpenTableRequest {
|
||||
.field("lance_read_params", &self.lance_read_params)
|
||||
.field("location", &self.location)
|
||||
.field("namespace_client", &self.namespace_client)
|
||||
.field("managed_versioning", &self.managed_versioning)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -85,8 +90,10 @@ pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableRequest) -> OpenTableReq
|
||||
|
||||
/// Describes what happens when creating a table and a table with
|
||||
/// the same name already exists
|
||||
#[derive(Default)]
|
||||
pub enum CreateTableMode {
|
||||
/// If the table already exists, an error is returned
|
||||
#[default]
|
||||
Create,
|
||||
/// If the table already exists, it is opened. Any provided data is
|
||||
/// ignored. The function will be passed an OpenTableBuilder to customize
|
||||
@@ -104,12 +111,6 @@ impl CreateTableMode {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CreateTableMode {
|
||||
fn default() -> Self {
|
||||
Self::Create
|
||||
}
|
||||
}
|
||||
|
||||
/// A request to create a table
|
||||
pub struct CreateTableRequest {
|
||||
/// The name of the new table
|
||||
@@ -165,6 +166,9 @@ pub struct CloneTableRequest {
|
||||
/// Whether to perform a shallow clone (true) or deep clone (false). Defaults to true.
|
||||
/// Currently only shallow clone is supported.
|
||||
pub is_shallow: bool,
|
||||
/// Optional namespace client for managed versioning support.
|
||||
/// When set, enables the commit handler to track table versions through the namespace.
|
||||
pub namespace_client: Option<Arc<dyn LanceNamespace>>,
|
||||
}
|
||||
|
||||
impl CloneTableRequest {
|
||||
@@ -176,6 +180,7 @@ impl CloneTableRequest {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::path::Path;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use lance::dataset::refs::Ref;
|
||||
use lance::dataset::{builder::DatasetBuilder, ReadParams, WriteMode};
|
||||
use lance::dataset::{ReadParams, WriteMode, builder::DatasetBuilder};
|
||||
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
|
||||
use lance_datafusion::utils::StreamingWriteSource;
|
||||
use lance_encoding::version::LanceFileVersion;
|
||||
@@ -669,6 +669,7 @@ impl ListingDatabase {
|
||||
lance_read_params: None,
|
||||
location: None,
|
||||
namespace_client: None,
|
||||
managed_versioning: None,
|
||||
};
|
||||
let req = (callback)(req);
|
||||
let table = self.open_table(req).await?;
|
||||
@@ -869,6 +870,7 @@ impl Database for ListingDatabase {
|
||||
Some(write_params),
|
||||
self.read_consistency_interval,
|
||||
request.namespace_client,
|
||||
false, // server_side_query_enabled - listing database doesn't support server-side queries
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -946,7 +948,9 @@ impl Database for ListingDatabase {
|
||||
self.store_wrapper.clone(),
|
||||
None,
|
||||
self.read_consistency_interval,
|
||||
None,
|
||||
request.namespace_client,
|
||||
false, // server_side_query_enabled - listing database doesn't support server-side queries
|
||||
None, // managed_versioning - will be queried if namespace_client is provided
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1022,6 +1026,8 @@ impl Database for ListingDatabase {
|
||||
Some(read_params),
|
||||
self.read_consistency_interval,
|
||||
request.namespace_client,
|
||||
false, // server_side_query_enabled - listing database doesn't support server-side queries
|
||||
request.managed_versioning, // Pass through managed_versioning from request
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
@@ -1097,11 +1103,11 @@ impl Database for ListingDatabase {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Table;
|
||||
use crate::connection::ConnectRequest;
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::database::{CreateTableMode, CreateTableRequest};
|
||||
use crate::table::WriteOptions;
|
||||
use crate::Table;
|
||||
use arrow_array::{Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::path::PathBuf;
|
||||
@@ -1162,6 +1168,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1222,6 +1229,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1281,6 +1289,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1317,6 +1326,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: false, // Request deep clone
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1357,6 +1367,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1397,6 +1408,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1416,6 +1428,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1452,6 +1465,7 @@ mod tests {
|
||||
source_version: Some(1),
|
||||
source_tag: Some("v1.0".to_string()),
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1525,6 +1539,7 @@ mod tests {
|
||||
source_version: Some(initial_version),
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1603,6 +1618,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: Some("v1.0".to_string()),
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1654,6 +1670,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1746,6 +1763,7 @@ mod tests {
|
||||
source_version: None,
|
||||
source_tag: None,
|
||||
is_shallow: true,
|
||||
namespace_client: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -7,17 +7,20 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance::io::commit::namespace_manifest::LanceNamespaceExternalManifestStore;
|
||||
use lance_io::object_store::{ObjectStoreParams, StorageOptionsAccessor};
|
||||
use lance_namespace::{
|
||||
models::{
|
||||
CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
|
||||
DeclareTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse,
|
||||
DescribeTableRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest,
|
||||
ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
|
||||
},
|
||||
LanceNamespace,
|
||||
models::{
|
||||
CreateNamespaceRequest, CreateNamespaceResponse, DeclareTableRequest,
|
||||
DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
|
||||
DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
|
||||
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
|
||||
},
|
||||
};
|
||||
use lance_namespace_impls::ConnectBuilder;
|
||||
use log::warn;
|
||||
use lance_table::io::commit::CommitHandler;
|
||||
use lance_table::io::commit::external_manifest::ExternalManifestCommitHandler;
|
||||
|
||||
use crate::database::ReadConsistency;
|
||||
use crate::error::{Error, Result};
|
||||
@@ -205,54 +208,49 @@ impl Database for LanceNamespaceDatabase {
|
||||
let mut table_id = request.namespace.clone();
|
||||
table_id.push(request.name.clone());
|
||||
|
||||
// Try declare_table first, falling back to create_empty_table for backwards
|
||||
// compatibility with older namespace clients that don't support declare_table
|
||||
let declare_request = DeclareTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let location = match self.namespace.declare_table(declare_request).await {
|
||||
Ok(response) => response.location.ok_or_else(|| Error::Runtime {
|
||||
let (location, initial_storage_options, managed_versioning) = {
|
||||
let response = self.namespace.declare_table(declare_request).await?;
|
||||
let loc = response.location.ok_or_else(|| Error::Runtime {
|
||||
message: "Table location is missing from declare_table response".to_string(),
|
||||
})?,
|
||||
Err(e) => {
|
||||
// Check if the error is "not supported" and try create_empty_table as fallback
|
||||
let err_str = e.to_string().to_lowercase();
|
||||
if err_str.contains("not supported") || err_str.contains("not implemented") {
|
||||
warn!(
|
||||
"declare_table is not supported by the namespace client, \
|
||||
falling back to deprecated create_empty_table. \
|
||||
create_empty_table is deprecated and will be removed in Lance 3.0.0. \
|
||||
Please upgrade your namespace client to support declare_table."
|
||||
);
|
||||
#[allow(deprecated)]
|
||||
let create_empty_request = CreateEmptyTableRequest {
|
||||
id: Some(table_id.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
let create_response = self
|
||||
.namespace
|
||||
.create_empty_table(create_empty_request)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create empty table: {}", e),
|
||||
})?;
|
||||
|
||||
create_response.location.ok_or_else(|| Error::Runtime {
|
||||
message: "Table location is missing from create_empty_table response"
|
||||
.to_string(),
|
||||
})?
|
||||
} else {
|
||||
return Err(Error::Runtime {
|
||||
message: format!("Failed to declare table: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
})?;
|
||||
// Use storage options from response, fall back to self.storage_options
|
||||
let opts = response
|
||||
.storage_options
|
||||
.or_else(|| Some(self.storage_options.clone()))
|
||||
.filter(|o| !o.is_empty());
|
||||
(loc, opts, response.managed_versioning)
|
||||
};
|
||||
|
||||
// Build write params with storage options and commit handler
|
||||
let mut params = request.write_options.lance_write_params.unwrap_or_default();
|
||||
|
||||
// Set up storage options if provided
|
||||
if let Some(storage_opts) = initial_storage_options {
|
||||
let store_params = params
|
||||
.store_params
|
||||
.get_or_insert_with(ObjectStoreParams::default);
|
||||
store_params.storage_options_accessor = Some(Arc::new(
|
||||
StorageOptionsAccessor::with_static_options(storage_opts),
|
||||
));
|
||||
}
|
||||
|
||||
// Set up commit handler when managed_versioning is enabled
|
||||
if managed_versioning == Some(true) {
|
||||
let external_store =
|
||||
LanceNamespaceExternalManifestStore::new(self.namespace.clone(), table_id.clone());
|
||||
let commit_handler: Arc<dyn CommitHandler> = Arc::new(ExternalManifestCommitHandler {
|
||||
external_manifest_store: Arc::new(external_store),
|
||||
});
|
||||
params.commit_handler = Some(commit_handler);
|
||||
}
|
||||
|
||||
let write_params = Some(params);
|
||||
|
||||
let native_table = NativeTable::create_from_namespace(
|
||||
self.namespace.clone(),
|
||||
&location,
|
||||
@@ -260,7 +258,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
request.namespace.clone(),
|
||||
request.data,
|
||||
None, // write_store_wrapper not used for namespace connections
|
||||
request.write_options.lance_write_params,
|
||||
write_params,
|
||||
self.read_consistency_interval,
|
||||
self.server_side_query_enabled,
|
||||
self.session.clone(),
|
||||
|
||||
@@ -11,16 +11,16 @@ use lance_core::ROW_ID;
|
||||
use lance_datafusion::exec::SessionContextExt;
|
||||
|
||||
use crate::{
|
||||
Error, Result, Table,
|
||||
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
|
||||
connect,
|
||||
database::{CreateTableRequest, Database},
|
||||
dataloader::permutation::{
|
||||
shuffle::{Shuffler, ShufflerConfig},
|
||||
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
||||
util::{rename_column, TemporaryDirectory},
|
||||
split::{SPLIT_ID_COLUMN, SplitStrategy, Splitter},
|
||||
util::{TemporaryDirectory, rename_column},
|
||||
},
|
||||
query::{ExecutableQuery, QueryBase, Select},
|
||||
Error, Result, Table,
|
||||
};
|
||||
|
||||
pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||
@@ -57,7 +57,7 @@ pub struct PermutationConfig {
|
||||
}
|
||||
|
||||
/// Strategy for shuffling the data.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum ShuffleStrategy {
|
||||
/// The data is randomly shuffled
|
||||
///
|
||||
@@ -78,15 +78,10 @@ pub enum ShuffleStrategy {
|
||||
/// The data is not shuffled
|
||||
///
|
||||
/// This is useful for debugging and testing.
|
||||
#[default]
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for ShuffleStrategy {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating a permutation table.
|
||||
///
|
||||
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
|
||||
|
||||
@@ -25,8 +25,8 @@ use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::io::RecordBatchStream;
|
||||
use lance_arrow::RecordBatchExt;
|
||||
use lance_core::error::LanceOptionExt;
|
||||
use lance_core::ROW_ID;
|
||||
use lance_core::error::LanceOptionExt;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -426,6 +426,7 @@ impl PermutationReader {
|
||||
row_ids_query = row_ids_query.limit(limit as usize);
|
||||
}
|
||||
let mut row_ids = row_ids_query.execute().await?;
|
||||
let mut idx_offset = 0;
|
||||
while let Some(batch) = row_ids.try_next().await? {
|
||||
let row_ids = batch
|
||||
.column(0)
|
||||
@@ -433,8 +434,9 @@ impl PermutationReader {
|
||||
.values()
|
||||
.to_vec();
|
||||
for (i, row_id) in row_ids.iter().enumerate() {
|
||||
offset_map.insert(i as u64, *row_id);
|
||||
offset_map.insert(i as u64 + idx_offset, *row_id);
|
||||
}
|
||||
idx_offset += batch.num_rows() as u64;
|
||||
}
|
||||
let offset_map = Arc::new(offset_map);
|
||||
*offset_map_ref = Some(offset_map.clone());
|
||||
@@ -498,10 +500,10 @@ mod tests {
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
use crate::{
|
||||
Table,
|
||||
arrow::SendableRecordBatchStream,
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
|
||||
Table,
|
||||
test_utils::datagen::{LanceDbDatagenExt, virtual_table},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -845,4 +847,106 @@ mod tests {
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, vec![row_ids[2] as i32]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_filtered_permutation_full_iteration() {
|
||||
use crate::dataloader::permutation::builder::PermutationBuilder;
|
||||
|
||||
// Create a base table with 10000 rows where idx goes 0..10000.
|
||||
// Filter to even values only, giving 5000 rows in the permutation.
|
||||
let base_table = lance_datagen::gen_batch()
|
||||
.col("idx", lance_datagen::array::step::<Int32Type>())
|
||||
.into_mem_table("tbl", RowCount::from(10000), BatchCount::from(1))
|
||||
.await;
|
||||
|
||||
let permutation_table = PermutationBuilder::new(base_table.clone())
|
||||
.with_filter("idx % 2 = 0".to_string())
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 5000);
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
permutation_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(reader.count_rows(), 5000);
|
||||
|
||||
// Iterate through all batches using a batch size that doesn't evenly divide
|
||||
// the row count (5000 / 128 = 39 full batches + 1 batch of 8 rows).
|
||||
let batch_size = 128;
|
||||
let mut stream = reader
|
||||
.read(
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: batch_size,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut total_rows = 0u64;
|
||||
let mut all_idx_values = Vec::new();
|
||||
while let Some(batch) = stream.try_next().await.unwrap() {
|
||||
assert!(batch.num_rows() <= batch_size as usize);
|
||||
total_rows += batch.num_rows() as u64;
|
||||
let idx_col = batch.column(0).as_primitive::<Int32Type>().values();
|
||||
all_idx_values.extend(idx_col.iter().copied());
|
||||
}
|
||||
|
||||
assert_eq!(total_rows, 5000);
|
||||
assert_eq!(all_idx_values.len(), 5000);
|
||||
|
||||
// Every value should be even (from the filter)
|
||||
assert!(all_idx_values.iter().all(|v| v % 2 == 0));
|
||||
|
||||
// Should have 5000 unique values
|
||||
let unique: std::collections::HashSet<i32> = all_idx_values.iter().copied().collect();
|
||||
assert_eq!(unique.len(), 5000);
|
||||
|
||||
// Use take_offsets to fetch rows from the beginning, middle, and end
|
||||
// of the permutation. The values should match what we saw during iteration.
|
||||
|
||||
// Beginning
|
||||
let batch = reader.take_offsets(&[0, 1, 2], Select::All).await.unwrap();
|
||||
assert_eq!(batch.num_rows(), 3);
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, &all_idx_values[0..3]);
|
||||
|
||||
// Middle
|
||||
let batch = reader
|
||||
.take_offsets(&[2499, 2500, 2501], Select::All)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(batch.num_rows(), 3);
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, &all_idx_values[2499..2502]);
|
||||
|
||||
// End (last 3 rows)
|
||||
let batch = reader
|
||||
.take_offsets(&[4997, 4998, 4999], Select::All)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(batch.num_rows(), 3);
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, &all_idx_values[4997..5000]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,12 +18,12 @@ use lance_io::{
|
||||
scheduler::{ScanScheduler, SchedulerConfig},
|
||||
utils::CachedFileSize,
|
||||
};
|
||||
use rand::{seq::SliceRandom, Rng, RngCore};
|
||||
use rand::{Rng, RngCore, seq::SliceRandom};
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
|
||||
Error, Result,
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::permutation::util::{TemporaryDirectory, non_crypto_rng},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -281,7 +281,7 @@ mod tests {
|
||||
use datafusion_expr::col;
|
||||
use futures::TryStreamExt;
|
||||
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
fn test_gen() -> BatchGeneratorBuilder {
|
||||
lance_datagen::gen_batch()
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
|
||||
@@ -15,21 +15,22 @@ use lance_arrow::SchemaExt;
|
||||
use lance_core::ROW_ID;
|
||||
|
||||
use crate::{
|
||||
Error, Result,
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::{
|
||||
permutation::shuffle::{Shuffler, ShufflerConfig},
|
||||
permutation::util::TemporaryDirectory,
|
||||
},
|
||||
query::{Query, QueryBase, Select},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
pub const SPLIT_ID_COLUMN: &str = "split_id";
|
||||
|
||||
/// Strategy for assigning rows to splits
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum SplitStrategy {
|
||||
/// All rows will have split id 0
|
||||
#[default]
|
||||
NoSplit,
|
||||
/// Rows will be randomly assigned to splits
|
||||
///
|
||||
@@ -73,15 +74,6 @@ pub enum SplitStrategy {
|
||||
Calculated { calculation: String },
|
||||
}
|
||||
|
||||
// The default is not to split the data
|
||||
//
|
||||
// All data will be assigned to a single split.
|
||||
impl Default for SplitStrategy {
|
||||
fn default() -> Self {
|
||||
Self::NoSplit
|
||||
}
|
||||
}
|
||||
|
||||
impl SplitStrategy {
|
||||
pub fn validate(&self, num_rows: u64) -> Result<()> {
|
||||
match self {
|
||||
|
||||
@@ -7,12 +7,12 @@ use arrow_array::RecordBatch;
|
||||
use arrow_schema::{Fields, Schema};
|
||||
use datafusion_execution::disk_manager::DiskManagerMode;
|
||||
use futures::TryStreamExt;
|
||||
use rand::{rngs::SmallRng, RngCore, SeedableRng};
|
||||
use rand::{RngCore, SeedableRng, rngs::SmallRng};
|
||||
use tempfile::TempDir;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
Error, Result,
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
};
|
||||
|
||||
/// Directory to use for temporary files
|
||||
|
||||
@@ -23,9 +23,9 @@ use arrow_schema::{DataType, Field, SchemaBuilder, SchemaRef};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
Error,
|
||||
error::Result,
|
||||
table::{ColumnDefinition, ColumnKind, TableDefinition},
|
||||
Error,
|
||||
};
|
||||
|
||||
/// Trait for embedding functions
|
||||
|
||||
@@ -8,7 +8,7 @@ use arrow::array::{AsArray, Float32Builder};
|
||||
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
|
||||
use arrow_data::ArrayData;
|
||||
use arrow_schema::DataType;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use super::EmbeddingFunction;
|
||||
use crate::{Error, Result};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user