mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-01 03:10:43 +00:00
Compare commits
1 Commits
python-v0.
...
codex/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beaf2e35dc |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.27.2-beta.1"
|
current_version = "0.27.0-beta.5"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -23,10 +23,8 @@ runs:
|
|||||||
steps:
|
steps:
|
||||||
- name: CONFIRM ARM BUILD
|
- name: CONFIRM ARM BUILD
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
|
||||||
ARM_BUILD: ${{ inputs.arm-build }}
|
|
||||||
run: |
|
run: |
|
||||||
echo "ARM BUILD: $ARM_BUILD"
|
echo "ARM BUILD: ${{ inputs.arm-build }}"
|
||||||
- name: Build x86_64 Manylinux wheel
|
- name: Build x86_64 Manylinux wheel
|
||||||
if: ${{ inputs.arm-build == 'false' }}
|
if: ${{ inputs.arm-build == 'false' }}
|
||||||
uses: PyO3/maturin-action@v1
|
uses: PyO3/maturin-action@v1
|
||||||
|
|||||||
1
.github/workflows/nodejs.yml
vendored
1
.github/workflows/nodejs.yml
vendored
@@ -7,7 +7,6 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- Cargo.toml
|
- Cargo.toml
|
||||||
- Cargo.lock
|
|
||||||
- nodejs/**
|
- nodejs/**
|
||||||
- rust/**
|
- rust/**
|
||||||
- docs/src/js/**
|
- docs/src/js/**
|
||||||
|
|||||||
17
.github/workflows/npm-publish.yml
vendored
17
.github/workflows/npm-publish.yml
vendored
@@ -19,7 +19,6 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- .github/workflows/npm-publish.yml
|
- .github/workflows/npm-publish.yml
|
||||||
- Cargo.toml # Change in dependency frequently breaks builds
|
- Cargo.toml # Change in dependency frequently breaks builds
|
||||||
- Cargo.lock
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
@@ -125,12 +124,7 @@ jobs:
|
|||||||
pre_build: |-
|
pre_build: |-
|
||||||
set -e &&
|
set -e &&
|
||||||
apt-get update &&
|
apt-get update &&
|
||||||
apt-get install -y protobuf-compiler pkg-config &&
|
apt-get install -y protobuf-compiler pkg-config
|
||||||
# The base image (manylinux2014-cross) sets TARGET_CC to the old
|
|
||||||
# GCC 4.8 cross-compiler. aws-lc-sys checks TARGET_CC before CC,
|
|
||||||
# so it picks up GCC even though the napi-rs image sets CC=clang.
|
|
||||||
# Override to use the image's clang-18 which supports -fuse-ld=lld.
|
|
||||||
export TARGET_CC=clang TARGET_CXX=clang++
|
|
||||||
- target: x86_64-unknown-linux-musl
|
- target: x86_64-unknown-linux-musl
|
||||||
# This one seems to need some extra memory
|
# This one seems to need some extra memory
|
||||||
host: ubuntu-2404-8x-x64
|
host: ubuntu-2404-8x-x64
|
||||||
@@ -150,10 +144,9 @@ jobs:
|
|||||||
set -e &&
|
set -e &&
|
||||||
apt-get update &&
|
apt-get update &&
|
||||||
apt-get install -y protobuf-compiler pkg-config &&
|
apt-get install -y protobuf-compiler pkg-config &&
|
||||||
export TARGET_CC=clang TARGET_CXX=clang++ &&
|
# https://github.com/aws/aws-lc-rs/issues/737#issuecomment-2725918627
|
||||||
# The manylinux2014 sysroot has glibc 2.17 headers which lack
|
ln -s /usr/aarch64-unknown-linux-gnu/lib/gcc/aarch64-unknown-linux-gnu/4.8.5/crtbeginS.o /usr/aarch64-unknown-linux-gnu/aarch64-unknown-linux-gnu/sysroot/usr/lib/crtbeginS.o &&
|
||||||
# AT_HWCAP2 (added in Linux 3.17). Define it for aws-lc-sys.
|
ln -s /usr/aarch64-unknown-linux-gnu/lib/gcc /usr/aarch64-unknown-linux-gnu/aarch64-unknown-linux-gnu/sysroot/usr/lib/gcc &&
|
||||||
export CFLAGS="$CFLAGS -DAT_HWCAP2=26" &&
|
|
||||||
rustup target add aarch64-unknown-linux-gnu
|
rustup target add aarch64-unknown-linux-gnu
|
||||||
- target: aarch64-unknown-linux-musl
|
- target: aarch64-unknown-linux-musl
|
||||||
host: ubuntu-2404-8x-x64
|
host: ubuntu-2404-8x-x64
|
||||||
@@ -273,7 +266,7 @@ jobs:
|
|||||||
- target: x86_64-unknown-linux-gnu
|
- target: x86_64-unknown-linux-gnu
|
||||||
host: ubuntu-latest
|
host: ubuntu-latest
|
||||||
- target: aarch64-unknown-linux-gnu
|
- target: aarch64-unknown-linux-gnu
|
||||||
host: ubuntu-2404-8x-arm64
|
host: buildjet-16vcpu-ubuntu-2204-arm
|
||||||
node:
|
node:
|
||||||
- '20'
|
- '20'
|
||||||
runs-on: ${{ matrix.settings.host }}
|
runs-on: ${{ matrix.settings.host }}
|
||||||
|
|||||||
1
.github/workflows/pypi-publish.yml
vendored
1
.github/workflows/pypi-publish.yml
vendored
@@ -9,7 +9,6 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- .github/workflows/pypi-publish.yml
|
- .github/workflows/pypi-publish.yml
|
||||||
- Cargo.toml # Change in dependency frequently breaks builds
|
- Cargo.toml # Change in dependency frequently breaks builds
|
||||||
- Cargo.lock
|
|
||||||
|
|
||||||
env:
|
env:
|
||||||
PIP_EXTRA_INDEX_URL: "https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/"
|
PIP_EXTRA_INDEX_URL: "https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/"
|
||||||
|
|||||||
1
.github/workflows/python.yml
vendored
1
.github/workflows/python.yml
vendored
@@ -7,7 +7,6 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- Cargo.toml
|
- Cargo.toml
|
||||||
- Cargo.lock
|
|
||||||
- python/**
|
- python/**
|
||||||
- rust/**
|
- rust/**
|
||||||
- .github/workflows/python.yml
|
- .github/workflows/python.yml
|
||||||
|
|||||||
17
.github/workflows/rust.yml
vendored
17
.github/workflows/rust.yml
vendored
@@ -7,7 +7,6 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- Cargo.toml
|
- Cargo.toml
|
||||||
- Cargo.lock
|
|
||||||
- rust/**
|
- rust/**
|
||||||
- .github/workflows/rust.yml
|
- .github/workflows/rust.yml
|
||||||
|
|
||||||
@@ -207,14 +206,14 @@ jobs:
|
|||||||
- name: Downgrade dependencies
|
- name: Downgrade dependencies
|
||||||
# These packages have newer requirements for MSRV
|
# These packages have newer requirements for MSRV
|
||||||
run: |
|
run: |
|
||||||
cargo update -p aws-sdk-bedrockruntime --precise 1.77.0
|
cargo update -p aws-sdk-bedrockruntime --precise 1.64.0
|
||||||
cargo update -p aws-sdk-dynamodb --precise 1.68.0
|
cargo update -p aws-sdk-dynamodb --precise 1.55.0
|
||||||
cargo update -p aws-config --precise 1.6.0
|
cargo update -p aws-config --precise 1.5.10
|
||||||
cargo update -p aws-sdk-kms --precise 1.63.0
|
cargo update -p aws-sdk-kms --precise 1.51.0
|
||||||
cargo update -p aws-sdk-s3 --precise 1.79.0
|
cargo update -p aws-sdk-s3 --precise 1.65.0
|
||||||
cargo update -p aws-sdk-sso --precise 1.62.0
|
cargo update -p aws-sdk-sso --precise 1.50.0
|
||||||
cargo update -p aws-sdk-ssooidc --precise 1.63.0
|
cargo update -p aws-sdk-ssooidc --precise 1.51.0
|
||||||
cargo update -p aws-sdk-sts --precise 1.63.0
|
cargo update -p aws-sdk-sts --precise 1.51.0
|
||||||
cargo update -p home --precise 0.5.9
|
cargo update -p home --precise 0.5.9
|
||||||
- name: cargo +${{ matrix.msrv }} check
|
- name: cargo +${{ matrix.msrv }} check
|
||||||
env:
|
env:
|
||||||
|
|||||||
2115
Cargo.lock
generated
2115
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
|
|||||||
rust-version = "1.91.0"
|
rust-version = "1.91.0"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { version = "=4.0.0", default-features = false }
|
lance = { "version" = "=4.0.0-beta.11", default-features = false, "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-core = { version = "=4.0.0" }
|
lance-core = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-datagen = { version = "=4.0.0" }
|
lance-datagen = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-file = { version = "=4.0.0" }
|
lance-file = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-io = { version = "=4.0.0", default-features = false }
|
lance-io = { "version" = "=4.0.0-beta.11", default-features = false, "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-index = { version = "=4.0.0" }
|
lance-index = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-linalg = { version = "=4.0.0" }
|
lance-linalg = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-namespace = { version = "=4.0.0" }
|
lance-namespace = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-namespace-impls = { version = "=4.0.0", default-features = false }
|
lance-namespace-impls = { "version" = "=4.0.0-beta.11", default-features = false, "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-table = { version = "=4.0.0" }
|
lance-table = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-testing = { version = "=4.0.0" }
|
lance-testing = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-datafusion = { version = "=4.0.0" }
|
lance-datafusion = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-encoding = { version = "=4.0.0" }
|
lance-encoding = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
lance-arrow = { version = "=4.0.0" }
|
lance-arrow = { "version" = "=4.0.0-beta.11", "tag" = "v4.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
|
||||||
ahash = "0.8"
|
ahash = "0.8"
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "57.2", optional = false }
|
arrow = { version = "57.2", optional = false }
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import functools
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -27,7 +26,6 @@ SEMVER_RE = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.total_ordering
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SemVer:
|
class SemVer:
|
||||||
major: int
|
major: int
|
||||||
@@ -158,9 +156,7 @@ def read_current_version(repo_root: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def determine_latest_tag(tags: Iterable[TagInfo]) -> TagInfo:
|
def determine_latest_tag(tags: Iterable[TagInfo]) -> TagInfo:
|
||||||
# Stable releases (no prerelease) are always preferred over pre-releases.
|
return max(tags, key=lambda tag: tag.semver)
|
||||||
# Within each group, standard semver ordering applies.
|
|
||||||
return max(tags, key=lambda tag: (not tag.semver.prerelease, tag.semver))
|
|
||||||
|
|
||||||
|
|
||||||
def write_outputs(args: argparse.Namespace, payload: dict) -> None:
|
def write_outputs(args: argparse.Namespace, payload: dict) -> None:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
version: "3.9"
|
version: "3.9"
|
||||||
services:
|
services:
|
||||||
localstack:
|
localstack:
|
||||||
image: localstack/localstack:4.0
|
image: localstack/localstack:3.3
|
||||||
ports:
|
ports:
|
||||||
- 4566:4566
|
- 4566:4566
|
||||||
environment:
|
environment:
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
mkdocs==1.6.1
|
mkdocs==1.5.3
|
||||||
mkdocs-jupyter==0.24.1
|
mkdocs-jupyter==0.24.1
|
||||||
mkdocs-material==9.6.23
|
mkdocs-material==9.5.3
|
||||||
mkdocs-autorefs>=0.5,<=1.0
|
mkdocs-autorefs>=0.5,<=1.0
|
||||||
mkdocstrings[python]>=0.24,<1.0
|
mkdocstrings[python]==0.25.2
|
||||||
griffe>=0.40,<1.0
|
griffe>=0.40,<1.0
|
||||||
mkdocs-render-swagger-plugin>=0.1.0
|
mkdocs-render-swagger-plugin>=0.1.0
|
||||||
pydantic>=2.0,<3.0
|
pydantic>=2.0,<3.0
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-core</artifactId>
|
<artifactId>lancedb-core</artifactId>
|
||||||
<version>0.27.2-beta.1</version>
|
<version>0.27.0-beta.5</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -485,7 +485,19 @@ Modeled after ``VACUUM`` in PostgreSQL.
|
|||||||
- Prune: Removes old versions of the dataset
|
- Prune: Removes old versions of the dataset
|
||||||
- Index: Optimizes the indices, adding new data to existing indices
|
- Index: Optimizes the indices, adding new data to existing indices
|
||||||
|
|
||||||
The frequency an application should call optimize is based on the frequency of
|
Experimental API
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The optimization process is undergoing active development and may change.
|
||||||
|
Our goal with these changes is to improve the performance of optimization and
|
||||||
|
reduce the complexity.
|
||||||
|
|
||||||
|
That being said, it is essential today to run optimize if you want the best
|
||||||
|
performance. It should be stable and safe to use in production, but it our
|
||||||
|
hope that the API may be simplified (or not even need to be called) in the
|
||||||
|
future.
|
||||||
|
|
||||||
|
The frequency an application shoudl call optimize is based on the frequency of
|
||||||
data modifications. If data is frequently added, deleted, or updated then
|
data modifications. If data is frequently added, deleted, or updated then
|
||||||
optimize should be run frequently. A good rule of thumb is to run optimize if
|
optimize should be run frequently. A good rule of thumb is to run optimize if
|
||||||
you have added or modified 100,000 or more records or run more than 20 data
|
you have added or modified 100,000 or more records or run more than 20 data
|
||||||
|
|||||||
@@ -37,12 +37,3 @@ tbl.optimize({cleanupOlderThan: new Date()});
|
|||||||
```ts
|
```ts
|
||||||
deleteUnverified: boolean;
|
deleteUnverified: boolean;
|
||||||
```
|
```
|
||||||
|
|
||||||
Because they may be part of an in-progress transaction, files newer than
|
|
||||||
7 days old are not deleted by default. If you are sure that there are no
|
|
||||||
in-progress transactions, then you can set this to true to delete all
|
|
||||||
files older than `cleanupOlderThan`.
|
|
||||||
|
|
||||||
**WARNING**: This should only be set to true if you can guarantee that
|
|
||||||
no other process is currently working on this dataset. Otherwise the
|
|
||||||
dataset could be put into a corrupted state.
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ new EmbeddingFunction<T, M>(): EmbeddingFunction<T, M>
|
|||||||
### computeQueryEmbeddings()
|
### computeQueryEmbeddings()
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||||
```
|
```
|
||||||
|
|
||||||
Compute the embeddings for a single query
|
Compute the embeddings for a single query
|
||||||
@@ -63,7 +63,7 @@ Compute the embeddings for a single query
|
|||||||
|
|
||||||
#### Returns
|
#### Returns
|
||||||
|
|
||||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ new TextEmbeddingFunction<M>(): TextEmbeddingFunction<M>
|
|||||||
### computeQueryEmbeddings()
|
### computeQueryEmbeddings()
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||||
```
|
```
|
||||||
|
|
||||||
Compute the embeddings for a single query
|
Compute the embeddings for a single query
|
||||||
@@ -48,7 +48,7 @@ Compute the embeddings for a single query
|
|||||||
|
|
||||||
#### Returns
|
#### Returns
|
||||||
|
|
||||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||||
|
|
||||||
#### Overrides
|
#### Overrides
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,5 @@
|
|||||||
# Type Alias: IntoVector
|
# Type Alias: IntoVector
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
type IntoVector:
|
type IntoVector: Float32Array | Float64Array | number[] | Promise<Float32Array | Float64Array | number[]>;
|
||||||
| Float32Array
|
|
||||||
| Float64Array
|
|
||||||
| Uint8Array
|
|
||||||
| number[]
|
|
||||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -36,20 +36,6 @@ is also an [asynchronous API client](#connections-asynchronous).
|
|||||||
|
|
||||||
::: lancedb.table.Tags
|
::: lancedb.table.Tags
|
||||||
|
|
||||||
## Expressions
|
|
||||||
|
|
||||||
Type-safe expression builder for filters and projections. Use these instead
|
|
||||||
of raw SQL strings with [where][lancedb.query.LanceQueryBuilder.where] and
|
|
||||||
[select][lancedb.query.LanceQueryBuilder.select].
|
|
||||||
|
|
||||||
::: lancedb.expr.Expr
|
|
||||||
|
|
||||||
::: lancedb.expr.col
|
|
||||||
|
|
||||||
::: lancedb.expr.lit
|
|
||||||
|
|
||||||
::: lancedb.expr.func
|
|
||||||
|
|
||||||
## Querying (Synchronous)
|
## Querying (Synchronous)
|
||||||
|
|
||||||
::: lancedb.query.Query
|
::: lancedb.query.Query
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.27.2-beta.1</version>
|
<version>0.27.0-beta.5</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.27.2-beta.1</version>
|
<version>0.27.0-beta.5</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
<description>LanceDB Java SDK Parent POM</description>
|
<description>LanceDB Java SDK Parent POM</description>
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<arrow.version>15.0.0</arrow.version>
|
<arrow.version>15.0.0</arrow.version>
|
||||||
<lance-core.version>3.0.1</lance-core.version>
|
<lance-core.version>4.0.0-beta.11</lance-core.version>
|
||||||
<spotless.skip>false</spotless.skip>
|
<spotless.skip>false</spotless.skip>
|
||||||
<spotless.version>2.30.0</spotless.version>
|
<spotless.version>2.30.0</spotless.version>
|
||||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
version = "0.27.2-beta.1"
|
version = "0.27.0-beta.5"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
description.workspace = true
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
@@ -15,8 +15,6 @@ crate-type = ["cdylib"]
|
|||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
arrow-ipc.workspace = true
|
arrow-ipc.workspace = true
|
||||||
arrow-array.workspace = true
|
arrow-array.workspace = true
|
||||||
arrow-buffer = "57.2"
|
|
||||||
half.workspace = true
|
|
||||||
arrow-schema.workspace = true
|
arrow-schema.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
|||||||
@@ -1,110 +0,0 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
||||||
|
|
||||||
import * as tmp from "tmp";
|
|
||||||
|
|
||||||
import { type Table, connect } from "../lancedb";
|
|
||||||
import {
|
|
||||||
Field,
|
|
||||||
FixedSizeList,
|
|
||||||
Float32,
|
|
||||||
Int64,
|
|
||||||
Schema,
|
|
||||||
makeArrowTable,
|
|
||||||
} from "../lancedb/arrow";
|
|
||||||
|
|
||||||
describe("Vector query with different typed arrays", () => {
|
|
||||||
let tmpDir: tmp.DirResult;
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
tmpDir?.removeCallback();
|
|
||||||
});
|
|
||||||
|
|
||||||
async function createFloat32Table(): Promise<Table> {
|
|
||||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
|
||||||
const db = await connect(tmpDir.name);
|
|
||||||
const schema = new Schema([
|
|
||||||
new Field("id", new Int64(), true),
|
|
||||||
new Field(
|
|
||||||
"vec",
|
|
||||||
new FixedSizeList(2, new Field("item", new Float32())),
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
]);
|
|
||||||
const data = makeArrowTable(
|
|
||||||
[
|
|
||||||
{ id: 1n, vec: [1.0, 0.0] },
|
|
||||||
{ id: 2n, vec: [0.0, 1.0] },
|
|
||||||
{ id: 3n, vec: [1.0, 1.0] },
|
|
||||||
],
|
|
||||||
{ schema },
|
|
||||||
);
|
|
||||||
return db.createTable("test_f32", data);
|
|
||||||
}
|
|
||||||
|
|
||||||
it("should search with Float32Array (baseline)", async () => {
|
|
||||||
const table = await createFloat32Table();
|
|
||||||
const results = await table
|
|
||||||
.query()
|
|
||||||
.nearestTo(new Float32Array([1.0, 0.0]))
|
|
||||||
.limit(1)
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
expect(results.length).toBe(1);
|
|
||||||
expect(Number(results[0].id)).toBe(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should search with number[] (backward compat)", async () => {
|
|
||||||
const table = await createFloat32Table();
|
|
||||||
const results = await table
|
|
||||||
.query()
|
|
||||||
.nearestTo([1.0, 0.0])
|
|
||||||
.limit(1)
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
expect(results.length).toBe(1);
|
|
||||||
expect(Number(results[0].id)).toBe(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should search with Float64Array via raw path", async () => {
|
|
||||||
const table = await createFloat32Table();
|
|
||||||
const results = await table
|
|
||||||
.query()
|
|
||||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
|
||||||
.limit(1)
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
expect(results.length).toBe(1);
|
|
||||||
expect(Number(results[0].id)).toBe(1);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("should add multiple query vectors with Float64Array", async () => {
|
|
||||||
const table = await createFloat32Table();
|
|
||||||
const results = await table
|
|
||||||
.query()
|
|
||||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
|
||||||
.addQueryVector(new Float64Array([0.0, 1.0]))
|
|
||||||
.limit(2)
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
expect(results.length).toBeGreaterThanOrEqual(2);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Float16Array is only available in Node 22+; not in TypeScript's standard lib yet
|
|
||||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
|
||||||
.Float16Array as (new (values: number[]) => unknown) | undefined;
|
|
||||||
const hasFloat16 = float16ArrayCtor !== undefined;
|
|
||||||
const f16it = hasFloat16 ? it : it.skip;
|
|
||||||
|
|
||||||
f16it("should search with Float16Array via raw path", async () => {
|
|
||||||
const table = await createFloat32Table();
|
|
||||||
const results = await table
|
|
||||||
.query()
|
|
||||||
.nearestTo(new float16ArrayCtor!([1.0, 0.0]) as Float32Array)
|
|
||||||
.limit(1)
|
|
||||||
.toArray();
|
|
||||||
|
|
||||||
expect(results.length).toBe(1);
|
|
||||||
expect(Number(results[0].id)).toBe(1);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -117,9 +117,8 @@ export type TableLike =
|
|||||||
export type IntoVector =
|
export type IntoVector =
|
||||||
| Float32Array
|
| Float32Array
|
||||||
| Float64Array
|
| Float64Array
|
||||||
| Uint8Array
|
|
||||||
| number[]
|
| number[]
|
||||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
| Promise<Float32Array | Float64Array | number[]>;
|
||||||
|
|
||||||
export type MultiVector = IntoVector[];
|
export type MultiVector = IntoVector[];
|
||||||
|
|
||||||
@@ -127,48 +126,14 @@ export function isMultiVector(value: unknown): value is MultiVector {
|
|||||||
return Array.isArray(value) && isIntoVector(value[0]);
|
return Array.isArray(value) && isIntoVector(value[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Float16Array is not in TypeScript's standard lib yet; access dynamically
|
|
||||||
type Float16ArrayCtor = new (
|
|
||||||
...args: unknown[]
|
|
||||||
) => { buffer: ArrayBuffer; byteOffset: number; byteLength: number };
|
|
||||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
|
||||||
.Float16Array as Float16ArrayCtor | undefined;
|
|
||||||
|
|
||||||
export function isIntoVector(value: unknown): value is IntoVector {
|
export function isIntoVector(value: unknown): value is IntoVector {
|
||||||
return (
|
return (
|
||||||
value instanceof Float32Array ||
|
value instanceof Float32Array ||
|
||||||
value instanceof Float64Array ||
|
value instanceof Float64Array ||
|
||||||
value instanceof Uint8Array ||
|
|
||||||
(float16ArrayCtor !== undefined && value instanceof float16ArrayCtor) ||
|
|
||||||
(Array.isArray(value) && !Array.isArray(value[0]))
|
(Array.isArray(value) && !Array.isArray(value[0]))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract the underlying byte buffer and data type from a typed array
|
|
||||||
* for passing to the Rust NAPI layer without precision loss.
|
|
||||||
*/
|
|
||||||
export function extractVectorBuffer(
|
|
||||||
vector: Float32Array | Float64Array | Uint8Array,
|
|
||||||
): { data: Uint8Array; dtype: string } | null {
|
|
||||||
if (float16ArrayCtor !== undefined && vector instanceof float16ArrayCtor) {
|
|
||||||
return {
|
|
||||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
|
||||||
dtype: "float16",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
if (vector instanceof Float64Array) {
|
|
||||||
return {
|
|
||||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
|
||||||
dtype: "float64",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
if (vector instanceof Uint8Array && !(vector instanceof Float32Array)) {
|
|
||||||
return { data: vector, dtype: "uint8" };
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function isArrowTable(value: object): value is TableLike {
|
export function isArrowTable(value: object): value is TableLike {
|
||||||
if (value instanceof ArrowTable) return true;
|
if (value instanceof ArrowTable) return true;
|
||||||
return "schema" in value && "batches" in value;
|
return "schema" in value && "batches" in value;
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import {
|
|||||||
Table as ArrowTable,
|
Table as ArrowTable,
|
||||||
type IntoVector,
|
type IntoVector,
|
||||||
RecordBatch,
|
RecordBatch,
|
||||||
extractVectorBuffer,
|
|
||||||
fromBufferToRecordBatch,
|
fromBufferToRecordBatch,
|
||||||
fromRecordBatchToBuffer,
|
fromRecordBatchToBuffer,
|
||||||
tableFromIPC,
|
tableFromIPC,
|
||||||
@@ -662,8 +661,10 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
|||||||
const res = (async () => {
|
const res = (async () => {
|
||||||
try {
|
try {
|
||||||
const v = await vector;
|
const v = await vector;
|
||||||
|
const arr = Float32Array.from(v);
|
||||||
|
//
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||||
const value: any = this.addQueryVector(v);
|
const value: any = this.addQueryVector(arr);
|
||||||
const inner = value.inner as
|
const inner = value.inner as
|
||||||
| NativeVectorQuery
|
| NativeVectorQuery
|
||||||
| Promise<NativeVectorQuery>;
|
| Promise<NativeVectorQuery>;
|
||||||
@@ -675,12 +676,7 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
|||||||
return new VectorQuery(res);
|
return new VectorQuery(res);
|
||||||
} else {
|
} else {
|
||||||
super.doCall((inner) => {
|
super.doCall((inner) => {
|
||||||
const raw = Array.isArray(vector) ? null : extractVectorBuffer(vector);
|
inner.addQueryVector(Float32Array.from(vector));
|
||||||
if (raw) {
|
|
||||||
inner.addQueryVectorRaw(raw.data, raw.dtype);
|
|
||||||
} else {
|
|
||||||
inner.addQueryVector(Float32Array.from(vector as number[]));
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
@@ -769,23 +765,14 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
|||||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||||
*/
|
*/
|
||||||
nearestTo(vector: IntoVector): VectorQuery {
|
nearestTo(vector: IntoVector): VectorQuery {
|
||||||
const callNearestTo = (
|
|
||||||
inner: NativeQuery,
|
|
||||||
resolved: Float32Array | Float64Array | Uint8Array | number[],
|
|
||||||
): NativeVectorQuery => {
|
|
||||||
const raw = Array.isArray(resolved)
|
|
||||||
? null
|
|
||||||
: extractVectorBuffer(resolved);
|
|
||||||
if (raw) {
|
|
||||||
return inner.nearestToRaw(raw.data, raw.dtype);
|
|
||||||
}
|
|
||||||
return inner.nearestTo(Float32Array.from(resolved as number[]));
|
|
||||||
};
|
|
||||||
|
|
||||||
if (this.inner instanceof Promise) {
|
if (this.inner instanceof Promise) {
|
||||||
const nativeQuery = this.inner.then(async (inner) => {
|
const nativeQuery = this.inner.then(async (inner) => {
|
||||||
const resolved = vector instanceof Promise ? await vector : vector;
|
if (vector instanceof Promise) {
|
||||||
return callNearestTo(inner, resolved);
|
const arr = await vector.then((v) => Float32Array.from(v));
|
||||||
|
return inner.nearestTo(arr);
|
||||||
|
} else {
|
||||||
|
return inner.nearestTo(Float32Array.from(vector));
|
||||||
|
}
|
||||||
});
|
});
|
||||||
return new VectorQuery(nativeQuery);
|
return new VectorQuery(nativeQuery);
|
||||||
}
|
}
|
||||||
@@ -793,8 +780,10 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
|||||||
const res = (async () => {
|
const res = (async () => {
|
||||||
try {
|
try {
|
||||||
const v = await vector;
|
const v = await vector;
|
||||||
|
const arr = Float32Array.from(v);
|
||||||
|
//
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||||
const value: any = this.nearestTo(v);
|
const value: any = this.nearestTo(arr);
|
||||||
const inner = value.inner as
|
const inner = value.inner as
|
||||||
| NativeVectorQuery
|
| NativeVectorQuery
|
||||||
| Promise<NativeVectorQuery>;
|
| Promise<NativeVectorQuery>;
|
||||||
@@ -805,7 +794,7 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
|||||||
})();
|
})();
|
||||||
return new VectorQuery(res);
|
return new VectorQuery(res);
|
||||||
} else {
|
} else {
|
||||||
const vectorQuery = callNearestTo(this.inner, vector);
|
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||||
return new VectorQuery(vectorQuery);
|
return new VectorQuery(vectorQuery);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-musl.node",
|
"main": "lancedb.linux-arm64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-musl.node",
|
"main": "lancedb.linux-x64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": [
|
"os": [
|
||||||
"win32"
|
"win32"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"private": false,
|
"private": false,
|
||||||
"version": "0.27.2-beta.1",
|
"version": "0.27.0-beta.5",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use arrow_array::{
|
|
||||||
Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array,
|
|
||||||
Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array,
|
|
||||||
};
|
|
||||||
use arrow_buffer::ScalarBuffer;
|
|
||||||
use half::f16;
|
|
||||||
use lancedb::index::scalar::{
|
use lancedb::index::scalar::{
|
||||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||||
Operator, PhraseQuery,
|
Operator, PhraseQuery,
|
||||||
@@ -30,33 +24,6 @@ use crate::rerankers::RerankHybridCallbackArgs;
|
|||||||
use crate::rerankers::Reranker;
|
use crate::rerankers::Reranker;
|
||||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||||
|
|
||||||
fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result<Arc<dyn Array>> {
|
|
||||||
let buf = arrow_buffer::Buffer::from(data.to_vec());
|
|
||||||
let num_bytes = buf.len();
|
|
||||||
match dtype.as_str() {
|
|
||||||
"float16" => {
|
|
||||||
let scalar_buf = ScalarBuffer::<f16>::new(buf, 0, num_bytes / 2);
|
|
||||||
Ok(Arc::new(ArrowFloat16Array::new(scalar_buf, None)))
|
|
||||||
}
|
|
||||||
"float32" => {
|
|
||||||
let scalar_buf = ScalarBuffer::<f32>::new(buf, 0, num_bytes / 4);
|
|
||||||
Ok(Arc::new(ArrowFloat32Array::new(scalar_buf, None)))
|
|
||||||
}
|
|
||||||
"float64" => {
|
|
||||||
let scalar_buf = ScalarBuffer::<f64>::new(buf, 0, num_bytes / 8);
|
|
||||||
Ok(Arc::new(ArrowFloat64Array::new(scalar_buf, None)))
|
|
||||||
}
|
|
||||||
"uint8" => {
|
|
||||||
let scalar_buf = ScalarBuffer::<u8>::new(buf, 0, num_bytes);
|
|
||||||
Ok(Arc::new(ArrowUInt8Array::new(scalar_buf, None)))
|
|
||||||
}
|
|
||||||
_ => Err(napi::Error::from_reason(format!(
|
|
||||||
"Unsupported vector dtype: {}. Expected one of: float16, float32, float64, uint8",
|
|
||||||
dtype
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub struct Query {
|
pub struct Query {
|
||||||
inner: LanceDbQuery,
|
inner: LanceDbQuery,
|
||||||
@@ -111,13 +78,6 @@ impl Query {
|
|||||||
Ok(VectorQuery { inner })
|
Ok(VectorQuery { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub fn nearest_to_raw(&mut self, data: Uint8Array, dtype: String) -> Result<VectorQuery> {
|
|
||||||
let array = bytes_to_arrow_array(data, dtype)?;
|
|
||||||
let inner = self.inner.clone().nearest_to(array).default_error()?;
|
|
||||||
Ok(VectorQuery { inner })
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn fast_search(&mut self) {
|
pub fn fast_search(&mut self) {
|
||||||
self.inner = self.inner.clone().fast_search();
|
self.inner = self.inner.clone().fast_search();
|
||||||
@@ -203,13 +163,6 @@ impl VectorQuery {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
|
||||||
pub fn add_query_vector_raw(&mut self, data: Uint8Array, dtype: String) -> Result<()> {
|
|
||||||
let array = bytes_to_arrow_array(data, dtype)?;
|
|
||||||
self.inner = self.inner.clone().add_query_vector(array).default_error()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
||||||
let distance_type = parse_distance_type(distance_type)?;
|
let distance_type = parse_distance_type(distance_type)?;
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.30.2"
|
current_version = "0.30.0-beta.5"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
2
python/.gitignore
vendored
2
python/.gitignore
vendored
@@ -1,5 +1,3 @@
|
|||||||
# Test data created by some example tests
|
# Test data created by some example tests
|
||||||
data/
|
data/
|
||||||
_lancedb.pyd
|
_lancedb.pyd
|
||||||
# macOS debug symbols bundle generated during build
|
|
||||||
*.dSYM/
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.30.2"
|
version = "0.30.0-beta.5"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
@@ -23,7 +23,6 @@ lance-namespace.workspace = true
|
|||||||
lance-namespace-impls.workspace = true
|
lance-namespace-impls.workspace = true
|
||||||
lance-io.workspace = true
|
lance-io.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
log.workspace = true
|
|
||||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||||
pyo3-async-runtimes = { version = "0.26", features = [
|
pyo3-async-runtimes = { version = "0.26", features = [
|
||||||
"attributes",
|
"attributes",
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
|||||||
from .io import StorageOptionsProvider
|
from .io import StorageOptionsProvider
|
||||||
from .remote import ClientConfig
|
from .remote import ClientConfig
|
||||||
from .remote.db import RemoteDBConnection
|
from .remote.db import RemoteDBConnection
|
||||||
from .expr import Expr, col, lit, func
|
|
||||||
from .schema import vector
|
from .schema import vector
|
||||||
from .table import AsyncTable, Table
|
from .table import AsyncTable, Table
|
||||||
from ._lancedb import Session
|
from ._lancedb import Session
|
||||||
@@ -272,10 +271,6 @@ __all__ = [
|
|||||||
"AsyncConnection",
|
"AsyncConnection",
|
||||||
"AsyncLanceNamespaceDBConnection",
|
"AsyncLanceNamespaceDBConnection",
|
||||||
"AsyncTable",
|
"AsyncTable",
|
||||||
"col",
|
|
||||||
"Expr",
|
|
||||||
"func",
|
|
||||||
"lit",
|
|
||||||
"URI",
|
"URI",
|
||||||
"sanitize_uri",
|
"sanitize_uri",
|
||||||
"vector",
|
"vector",
|
||||||
|
|||||||
@@ -27,32 +27,6 @@ from .remote import ClientConfig
|
|||||||
IvfHnswPq: type[HnswPq] = HnswPq
|
IvfHnswPq: type[HnswPq] = HnswPq
|
||||||
IvfHnswSq: type[HnswSq] = HnswSq
|
IvfHnswSq: type[HnswSq] = HnswSq
|
||||||
|
|
||||||
class PyExpr:
|
|
||||||
"""A type-safe DataFusion expression node (Rust-side handle)."""
|
|
||||||
|
|
||||||
def eq(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def ne(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def lt(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def lte(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def gt(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def gte(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def and_(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def or_(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def not_(self) -> "PyExpr": ...
|
|
||||||
def add(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def sub(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def mul(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def div(self, other: "PyExpr") -> "PyExpr": ...
|
|
||||||
def lower(self) -> "PyExpr": ...
|
|
||||||
def upper(self) -> "PyExpr": ...
|
|
||||||
def contains(self, substr: "PyExpr") -> "PyExpr": ...
|
|
||||||
def cast(self, data_type: pa.DataType) -> "PyExpr": ...
|
|
||||||
def to_sql(self) -> str: ...
|
|
||||||
|
|
||||||
def expr_col(name: str) -> PyExpr: ...
|
|
||||||
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
|
|
||||||
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -161,10 +135,7 @@ class Table:
|
|||||||
def close(self) -> None: ...
|
def close(self) -> None: ...
|
||||||
async def schema(self) -> pa.Schema: ...
|
async def schema(self) -> pa.Schema: ...
|
||||||
async def add(
|
async def add(
|
||||||
self,
|
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
|
||||||
data: pa.RecordBatchReader,
|
|
||||||
mode: Literal["append", "overwrite"],
|
|
||||||
progress: Optional[Any] = None,
|
|
||||||
) -> AddResult: ...
|
) -> AddResult: ...
|
||||||
async def update(
|
async def update(
|
||||||
self, updates: Dict[str, str], where: Optional[str]
|
self, updates: Dict[str, str], where: Optional[str]
|
||||||
@@ -251,9 +222,7 @@ class RecordBatchStream:
|
|||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
def where(self, filter: str): ...
|
def where(self, filter: str): ...
|
||||||
def where_expr(self, expr: PyExpr): ...
|
def select(self, columns: Tuple[str, str]): ...
|
||||||
def select(self, columns: List[Tuple[str, str]]): ...
|
|
||||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
|
||||||
def select_columns(self, columns: List[str]): ...
|
def select_columns(self, columns: List[str]): ...
|
||||||
def limit(self, limit: int): ...
|
def limit(self, limit: int): ...
|
||||||
def offset(self, offset: int): ...
|
def offset(self, offset: int): ...
|
||||||
@@ -279,9 +248,7 @@ class TakeQuery:
|
|||||||
|
|
||||||
class FTSQuery:
|
class FTSQuery:
|
||||||
def where(self, filter: str): ...
|
def where(self, filter: str): ...
|
||||||
def where_expr(self, expr: PyExpr): ...
|
def select(self, columns: List[str]): ...
|
||||||
def select(self, columns: List[Tuple[str, str]]): ...
|
|
||||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
|
||||||
def limit(self, limit: int): ...
|
def limit(self, limit: int): ...
|
||||||
def offset(self, offset: int): ...
|
def offset(self, offset: int): ...
|
||||||
def fast_search(self): ...
|
def fast_search(self): ...
|
||||||
@@ -300,9 +267,7 @@ class VectorQuery:
|
|||||||
async def output_schema(self) -> pa.Schema: ...
|
async def output_schema(self) -> pa.Schema: ...
|
||||||
async def execute(self) -> RecordBatchStream: ...
|
async def execute(self) -> RecordBatchStream: ...
|
||||||
def where(self, filter: str): ...
|
def where(self, filter: str): ...
|
||||||
def where_expr(self, expr: PyExpr): ...
|
def select(self, columns: List[str]): ...
|
||||||
def select(self, columns: List[Tuple[str, str]]): ...
|
|
||||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
|
||||||
def select_with_projection(self, columns: Tuple[str, str]): ...
|
def select_with_projection(self, columns: Tuple[str, str]): ...
|
||||||
def limit(self, limit: int): ...
|
def limit(self, limit: int): ...
|
||||||
def offset(self, offset: int): ...
|
def offset(self, offset: int): ...
|
||||||
@@ -319,9 +284,7 @@ class VectorQuery:
|
|||||||
|
|
||||||
class HybridQuery:
|
class HybridQuery:
|
||||||
def where(self, filter: str): ...
|
def where(self, filter: str): ...
|
||||||
def where_expr(self, expr: PyExpr): ...
|
def select(self, columns: List[str]): ...
|
||||||
def select(self, columns: List[Tuple[str, str]]): ...
|
|
||||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
|
||||||
def limit(self, limit: int): ...
|
def limit(self, limit: int): ...
|
||||||
def offset(self, offset: int): ...
|
def offset(self, offset: int): ...
|
||||||
def fast_search(self): ...
|
def fast_search(self): ...
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
|
||||||
import weakref
|
import weakref
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|||||||
@@ -1,298 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
||||||
|
|
||||||
"""Type-safe expression builder for filters and projections.
|
|
||||||
|
|
||||||
Instead of writing raw SQL strings you can build expressions with Python
|
|
||||||
operators::
|
|
||||||
|
|
||||||
from lancedb.expr import col, lit
|
|
||||||
|
|
||||||
# filter: age > 18 AND status = 'active'
|
|
||||||
filt = (col("age") > lit(18)) & (col("status") == lit("active"))
|
|
||||||
|
|
||||||
# projection: compute a derived column
|
|
||||||
proj = {"score": col("raw_score") * lit(1.5)}
|
|
||||||
|
|
||||||
table.search().where(filt).select(proj).to_list()
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import pyarrow as pa
|
|
||||||
|
|
||||||
from lancedb._lancedb import PyExpr, expr_col, expr_lit, expr_func
|
|
||||||
|
|
||||||
__all__ = ["Expr", "col", "lit", "func"]
|
|
||||||
|
|
||||||
_STR_TO_PA_TYPE: dict = {
|
|
||||||
"bool": pa.bool_(),
|
|
||||||
"boolean": pa.bool_(),
|
|
||||||
"int8": pa.int8(),
|
|
||||||
"int16": pa.int16(),
|
|
||||||
"int32": pa.int32(),
|
|
||||||
"int64": pa.int64(),
|
|
||||||
"uint8": pa.uint8(),
|
|
||||||
"uint16": pa.uint16(),
|
|
||||||
"uint32": pa.uint32(),
|
|
||||||
"uint64": pa.uint64(),
|
|
||||||
"float16": pa.float16(),
|
|
||||||
"float32": pa.float32(),
|
|
||||||
"float": pa.float32(),
|
|
||||||
"float64": pa.float64(),
|
|
||||||
"double": pa.float64(),
|
|
||||||
"string": pa.string(),
|
|
||||||
"utf8": pa.string(),
|
|
||||||
"str": pa.string(),
|
|
||||||
"large_string": pa.large_utf8(),
|
|
||||||
"large_utf8": pa.large_utf8(),
|
|
||||||
"date32": pa.date32(),
|
|
||||||
"date": pa.date32(),
|
|
||||||
"date64": pa.date64(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce(value: "ExprLike") -> "Expr":
|
|
||||||
"""Return *value* as an :class:`Expr`, wrapping plain Python values via
|
|
||||||
:func:`lit` if needed."""
|
|
||||||
if isinstance(value, Expr):
|
|
||||||
return value
|
|
||||||
return lit(value)
|
|
||||||
|
|
||||||
|
|
||||||
# Type alias used in annotations.
|
|
||||||
ExprLike = Union["Expr", bool, int, float, str]
|
|
||||||
|
|
||||||
|
|
||||||
class Expr:
|
|
||||||
"""A type-safe expression node.
|
|
||||||
|
|
||||||
Construct instances with :func:`col` and :func:`lit`, then combine them
|
|
||||||
using Python operators or the named methods below.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from lancedb.expr import col, lit
|
|
||||||
>>> filt = (col("age") > lit(18)) & (col("name").lower() == lit("alice"))
|
|
||||||
>>> proj = {"double": col("x") * lit(2)}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Make Expr unhashable so that == returns an Expr rather than being used
|
|
||||||
# for dict keys / set membership.
|
|
||||||
__hash__ = None # type: ignore[assignment]
|
|
||||||
|
|
||||||
def __init__(self, inner: PyExpr) -> None:
|
|
||||||
self._inner = inner
|
|
||||||
|
|
||||||
# ── comparisons ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def __eq__(self, other: ExprLike) -> "Expr": # type: ignore[override]
|
|
||||||
"""Equal to (``col("x") == 1``)."""
|
|
||||||
return Expr(self._inner.eq(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __ne__(self, other: ExprLike) -> "Expr": # type: ignore[override]
|
|
||||||
"""Not equal to (``col("x") != 1``)."""
|
|
||||||
return Expr(self._inner.ne(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __lt__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Less than (``col("x") < 1``)."""
|
|
||||||
return Expr(self._inner.lt(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __le__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Less than or equal to (``col("x") <= 1``)."""
|
|
||||||
return Expr(self._inner.lte(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __gt__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Greater than (``col("x") > 1``)."""
|
|
||||||
return Expr(self._inner.gt(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __ge__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Greater than or equal to (``col("x") >= 1``)."""
|
|
||||||
return Expr(self._inner.gte(_coerce(other)._inner))
|
|
||||||
|
|
||||||
# ── logical ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def __and__(self, other: "Expr") -> "Expr":
|
|
||||||
"""Logical AND (``expr_a & expr_b``)."""
|
|
||||||
return Expr(self._inner.and_(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __or__(self, other: "Expr") -> "Expr":
|
|
||||||
"""Logical OR (``expr_a | expr_b``)."""
|
|
||||||
return Expr(self._inner.or_(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __invert__(self) -> "Expr":
|
|
||||||
"""Logical NOT (``~expr``)."""
|
|
||||||
return Expr(self._inner.not_())
|
|
||||||
|
|
||||||
# ── arithmetic ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def __add__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Add (``col("x") + 1``)."""
|
|
||||||
return Expr(self._inner.add(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __radd__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Right-hand add (``1 + col("x")``)."""
|
|
||||||
return Expr(_coerce(other)._inner.add(self._inner))
|
|
||||||
|
|
||||||
def __sub__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Subtract (``col("x") - 1``)."""
|
|
||||||
return Expr(self._inner.sub(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __rsub__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Right-hand subtract (``1 - col("x")``)."""
|
|
||||||
return Expr(_coerce(other)._inner.sub(self._inner))
|
|
||||||
|
|
||||||
def __mul__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Multiply (``col("x") * 2``)."""
|
|
||||||
return Expr(self._inner.mul(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __rmul__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Right-hand multiply (``2 * col("x")``)."""
|
|
||||||
return Expr(_coerce(other)._inner.mul(self._inner))
|
|
||||||
|
|
||||||
def __truediv__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Divide (``col("x") / 2``)."""
|
|
||||||
return Expr(self._inner.div(_coerce(other)._inner))
|
|
||||||
|
|
||||||
def __rtruediv__(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Right-hand divide (``1 / col("x")``)."""
|
|
||||||
return Expr(_coerce(other)._inner.div(self._inner))
|
|
||||||
|
|
||||||
# ── string methods ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def lower(self) -> "Expr":
|
|
||||||
"""Convert string column values to lowercase."""
|
|
||||||
return Expr(self._inner.lower())
|
|
||||||
|
|
||||||
def upper(self) -> "Expr":
|
|
||||||
"""Convert string column values to uppercase."""
|
|
||||||
return Expr(self._inner.upper())
|
|
||||||
|
|
||||||
def contains(self, substr: "ExprLike") -> "Expr":
|
|
||||||
"""Return True where the string contains *substr*."""
|
|
||||||
return Expr(self._inner.contains(_coerce(substr)._inner))
|
|
||||||
|
|
||||||
# ── type cast ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def cast(self, data_type: Union[str, "pa.DataType"]) -> "Expr":
|
|
||||||
"""Cast values to *data_type*.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
data_type:
|
|
||||||
A PyArrow ``DataType`` (e.g. ``pa.int32()``) or one of the type
|
|
||||||
name strings: ``"bool"``, ``"int8"``, ``"int16"``, ``"int32"``,
|
|
||||||
``"int64"``, ``"uint8"``–``"uint64"``, ``"float32"``,
|
|
||||||
``"float64"``, ``"string"``, ``"date32"``, ``"date64"``.
|
|
||||||
"""
|
|
||||||
if isinstance(data_type, str):
|
|
||||||
try:
|
|
||||||
data_type = _STR_TO_PA_TYPE[data_type]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(
|
|
||||||
f"unsupported data type: '{data_type}'. Supported: "
|
|
||||||
f"{', '.join(_STR_TO_PA_TYPE)}"
|
|
||||||
)
|
|
||||||
return Expr(self._inner.cast(data_type))
|
|
||||||
|
|
||||||
# ── named comparison helpers (alternative to operators) ──────────────────
|
|
||||||
|
|
||||||
def eq(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Equal to."""
|
|
||||||
return self.__eq__(other)
|
|
||||||
|
|
||||||
def ne(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Not equal to."""
|
|
||||||
return self.__ne__(other)
|
|
||||||
|
|
||||||
def lt(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Less than."""
|
|
||||||
return self.__lt__(other)
|
|
||||||
|
|
||||||
def lte(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Less than or equal to."""
|
|
||||||
return self.__le__(other)
|
|
||||||
|
|
||||||
def gt(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Greater than."""
|
|
||||||
return self.__gt__(other)
|
|
||||||
|
|
||||||
def gte(self, other: ExprLike) -> "Expr":
|
|
||||||
"""Greater than or equal to."""
|
|
||||||
return self.__ge__(other)
|
|
||||||
|
|
||||||
def and_(self, other: "Expr") -> "Expr":
|
|
||||||
"""Logical AND."""
|
|
||||||
return self.__and__(other)
|
|
||||||
|
|
||||||
def or_(self, other: "Expr") -> "Expr":
|
|
||||||
"""Logical OR."""
|
|
||||||
return self.__or__(other)
|
|
||||||
|
|
||||||
# ── utilities ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def to_sql(self) -> str:
|
|
||||||
"""Render the expression as a SQL string (useful for debugging)."""
|
|
||||||
return self._inner.to_sql()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"Expr({self._inner.to_sql()})"
|
|
||||||
|
|
||||||
|
|
||||||
# ── free functions ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def col(name: str) -> Expr:
|
|
||||||
"""Reference a table column by name.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name:
|
|
||||||
The column name.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from lancedb.expr import col, lit
|
|
||||||
>>> col("age") > lit(18)
|
|
||||||
Expr((age > 18))
|
|
||||||
"""
|
|
||||||
return Expr(expr_col(name))
|
|
||||||
|
|
||||||
|
|
||||||
def lit(value: Union[bool, int, float, str]) -> Expr:
|
|
||||||
"""Create a literal (constant) value expression.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
value:
|
|
||||||
A Python ``bool``, ``int``, ``float``, or ``str``.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from lancedb.expr import col, lit
|
|
||||||
>>> col("price") * lit(1.1)
|
|
||||||
Expr((price * 1.1))
|
|
||||||
"""
|
|
||||||
return Expr(expr_lit(value))
|
|
||||||
|
|
||||||
|
|
||||||
def func(name: str, *args: ExprLike) -> Expr:
|
|
||||||
"""Call an arbitrary SQL function by name.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name:
|
|
||||||
The SQL function name (e.g. ``"lower"``, ``"upper"``).
|
|
||||||
*args:
|
|
||||||
The function arguments as :class:`Expr` or plain Python literals.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from lancedb.expr import col, func
|
|
||||||
>>> func("lower", col("name"))
|
|
||||||
Expr(lower(name))
|
|
||||||
"""
|
|
||||||
inner_args = [_coerce(a)._inner for a in args]
|
|
||||||
return Expr(expr_func(name, inner_args))
|
|
||||||
@@ -38,7 +38,6 @@ from .rerankers.base import Reranker
|
|||||||
from .rerankers.rrf import RRFReranker
|
from .rerankers.rrf import RRFReranker
|
||||||
from .rerankers.util import check_reranker_result
|
from .rerankers.util import check_reranker_result
|
||||||
from .util import flatten_columns
|
from .util import flatten_columns
|
||||||
from .expr import Expr
|
|
||||||
from lancedb._lancedb import fts_query_to_json
|
from lancedb._lancedb import fts_query_to_json
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
@@ -71,7 +70,7 @@ def ensure_vector_query(
|
|||||||
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
|
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) == 0:
|
if len(val) == 0:
|
||||||
raise ValueError("Vector query must be a non-empty list")
|
return ValueError("Vector query must be a non-empty list")
|
||||||
sample = val[0]
|
sample = val[0]
|
||||||
else:
|
else:
|
||||||
if isinstance(val, float):
|
if isinstance(val, float):
|
||||||
@@ -84,7 +83,7 @@ def ensure_vector_query(
|
|||||||
return val
|
return val
|
||||||
if isinstance(sample, list):
|
if isinstance(sample, list):
|
||||||
if len(sample) == 0:
|
if len(sample) == 0:
|
||||||
raise ValueError("Vector query must be a non-empty list")
|
return ValueError("Vector query must be a non-empty list")
|
||||||
if isinstance(sample[0], float):
|
if isinstance(sample[0], float):
|
||||||
# val is list of list of floats
|
# val is list of list of floats
|
||||||
return val
|
return val
|
||||||
@@ -450,8 +449,8 @@ class Query(pydantic.BaseModel):
|
|||||||
ensure_vector_query,
|
ensure_vector_query,
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
# sql filter or type-safe Expr to refine the query with
|
# sql filter to refine the query with
|
||||||
filter: Optional[Union[str, Expr]] = None
|
filter: Optional[str] = None
|
||||||
|
|
||||||
# if True then apply the filter after vector search
|
# if True then apply the filter after vector search
|
||||||
postfilter: Optional[bool] = None
|
postfilter: Optional[bool] = None
|
||||||
@@ -465,8 +464,8 @@ class Query(pydantic.BaseModel):
|
|||||||
# distance type to use for vector search
|
# distance type to use for vector search
|
||||||
distance_type: Optional[str] = None
|
distance_type: Optional[str] = None
|
||||||
|
|
||||||
# which columns to return in the results (dict values may be str or Expr)
|
# which columns to return in the results
|
||||||
columns: Optional[Union[List[str], Dict[str, Union[str, Expr]]]] = None
|
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||||
|
|
||||||
# minimum number of IVF partitions to search
|
# minimum number of IVF partitions to search
|
||||||
#
|
#
|
||||||
@@ -857,15 +856,14 @@ class LanceQueryBuilder(ABC):
|
|||||||
self._offset = offset
|
self._offset = offset
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def select(self, columns: Union[list[str], dict[str, Union[str, Expr]]]) -> Self:
|
def select(self, columns: Union[list[str], dict[str, str]]) -> Self:
|
||||||
"""Set the columns to return.
|
"""Set the columns to return.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
columns: list of str, or dict of str to str or Expr
|
columns: list of str, or dict of str to str default None
|
||||||
List of column names to be fetched.
|
List of column names to be fetched.
|
||||||
Or a dictionary of column names to SQL expressions or
|
Or a dictionary of column names to SQL expressions.
|
||||||
:class:`~lancedb.expr.Expr` objects.
|
|
||||||
All columns are fetched if None or unspecified.
|
All columns are fetched if None or unspecified.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -879,15 +877,15 @@ class LanceQueryBuilder(ABC):
|
|||||||
raise ValueError("columns must be a list or a dictionary")
|
raise ValueError("columns must be a list or a dictionary")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def where(self, where: Union[str, Expr], prefilter: bool = True) -> Self:
|
def where(self, where: str, prefilter: bool = True) -> Self:
|
||||||
"""Set the where clause.
|
"""Set the where clause.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
where: str or :class:`~lancedb.expr.Expr`
|
where: str
|
||||||
The filter condition. Can be a SQL string or a type-safe
|
The where clause which is a valid SQL where clause. See
|
||||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||||
and :func:`~lancedb.expr.lit`.
|
for valid SQL expressions.
|
||||||
prefilter: bool, default True
|
prefilter: bool, default True
|
||||||
If True, apply the filter before vector search, otherwise the
|
If True, apply the filter before vector search, otherwise the
|
||||||
filter is applied on the result of vector search.
|
filter is applied on the result of vector search.
|
||||||
@@ -1357,17 +1355,15 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
|
|
||||||
return result_set
|
return result_set
|
||||||
|
|
||||||
def where(
|
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
|
||||||
self, where: Union[str, Expr], prefilter: bool = None
|
|
||||||
) -> LanceVectorQueryBuilder:
|
|
||||||
"""Set the where clause.
|
"""Set the where clause.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
where: str or :class:`~lancedb.expr.Expr`
|
where: str
|
||||||
The filter condition. Can be a SQL string or a type-safe
|
The where clause which is a valid SQL where clause. See
|
||||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||||
and :func:`~lancedb.expr.lit`.
|
for valid SQL expressions.
|
||||||
prefilter: bool, default True
|
prefilter: bool, default True
|
||||||
If True, apply the filter before vector search, otherwise the
|
If True, apply the filter before vector search, otherwise the
|
||||||
filter is applied on the result of vector search.
|
filter is applied on the result of vector search.
|
||||||
@@ -2209,8 +2205,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
|||||||
self._vector_query.select(self._columns)
|
self._vector_query.select(self._columns)
|
||||||
self._fts_query.select(self._columns)
|
self._fts_query.select(self._columns)
|
||||||
if self._where:
|
if self._where:
|
||||||
self._vector_query.where(self._where, not self._postfilter)
|
self._vector_query.where(self._where, self._postfilter)
|
||||||
self._fts_query.where(self._where, not self._postfilter)
|
self._fts_query.where(self._where, self._postfilter)
|
||||||
if self._with_row_id:
|
if self._with_row_id:
|
||||||
self._vector_query.with_row_id(True)
|
self._vector_query.with_row_id(True)
|
||||||
self._fts_query.with_row_id(True)
|
self._fts_query.with_row_id(True)
|
||||||
@@ -2290,20 +2286,10 @@ class AsyncQueryBase(object):
|
|||||||
"""
|
"""
|
||||||
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
|
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
|
||||||
self._inner.select_columns(columns)
|
self._inner.select_columns(columns)
|
||||||
elif isinstance(columns, dict) and all(isinstance(k, str) for k in columns):
|
elif isinstance(columns, dict) and all(
|
||||||
if any(isinstance(v, Expr) for v in columns.values()):
|
isinstance(k, str) and isinstance(v, str) for k, v in columns.items()
|
||||||
# At least one value is an Expr — use the type-safe path.
|
):
|
||||||
from .expr import _coerce
|
self._inner.select(list(columns.items()))
|
||||||
|
|
||||||
pairs = [(k, _coerce(v)._inner) for k, v in columns.items()]
|
|
||||||
self._inner.select_expr(pairs)
|
|
||||||
elif all(isinstance(v, str) for v in columns.values()):
|
|
||||||
self._inner.select(list(columns.items()))
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"dict values must be str or Expr, got "
|
|
||||||
+ str({k: type(v) for k, v in columns.items()})
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("columns must be a list of column names or a dict")
|
raise TypeError("columns must be a list of column names or a dict")
|
||||||
return self
|
return self
|
||||||
@@ -2543,13 +2529,11 @@ class AsyncStandardQuery(AsyncQueryBase):
|
|||||||
"""
|
"""
|
||||||
super().__init__(inner)
|
super().__init__(inner)
|
||||||
|
|
||||||
def where(self, predicate: Union[str, Expr]) -> Self:
|
def where(self, predicate: str) -> Self:
|
||||||
"""
|
"""
|
||||||
Only return rows matching the given predicate
|
Only return rows matching the given predicate
|
||||||
|
|
||||||
The predicate can be a SQL string or a type-safe
|
The predicate should be supplied as an SQL query string.
|
||||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
|
||||||
and :func:`~lancedb.expr.lit`.
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -2561,10 +2545,7 @@ class AsyncStandardQuery(AsyncQueryBase):
|
|||||||
Filtering performance can often be improved by creating a scalar index
|
Filtering performance can often be improved by creating a scalar index
|
||||||
on the filter column(s).
|
on the filter column(s).
|
||||||
"""
|
"""
|
||||||
if isinstance(predicate, Expr):
|
self._inner.where(predicate)
|
||||||
self._inner.where_expr(predicate._inner)
|
|
||||||
else:
|
|
||||||
self._inner.where(predicate)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def limit(self, limit: int) -> Self:
|
def limit(self, limit: int) -> Self:
|
||||||
|
|||||||
@@ -568,4 +568,4 @@ class RemoteDBConnection(DBConnection):
|
|||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the connection to the database."""
|
"""Close the connection to the database."""
|
||||||
self._conn.close()
|
self._client.close()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
|
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from lancedb._lancedb import (
|
from lancedb._lancedb import (
|
||||||
@@ -35,7 +35,6 @@ import pyarrow as pa
|
|||||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from lancedb.merge import LanceMergeInsertBuilder
|
from lancedb.merge import LanceMergeInsertBuilder
|
||||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
from lancedb.table import _normalize_progress
|
|
||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
||||||
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
|
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
|
||||||
@@ -309,7 +308,6 @@ class RemoteTable(Table):
|
|||||||
mode: str = "append",
|
mode: str = "append",
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
|
||||||
) -> AddResult:
|
) -> AddResult:
|
||||||
"""Add more data to the [Table](Table). It has the same API signature as
|
"""Add more data to the [Table](Table). It has the same API signature as
|
||||||
the OSS version.
|
the OSS version.
|
||||||
@@ -332,29 +330,17 @@ class RemoteTable(Table):
|
|||||||
One of "error", "drop", "fill".
|
One of "error", "drop", "fill".
|
||||||
fill_value: float, default 0.
|
fill_value: float, default 0.
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
progress: bool, callable, or tqdm-like, optional
|
|
||||||
A callback or tqdm-compatible progress bar. See
|
|
||||||
:meth:`Table.add` for details.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
AddResult
|
AddResult
|
||||||
An object containing the new version number of the table after adding data.
|
An object containing the new version number of the table after adding data.
|
||||||
"""
|
"""
|
||||||
progress, owns = _normalize_progress(progress)
|
return LOOP.run(
|
||||||
try:
|
self._table.add(
|
||||||
return LOOP.run(
|
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||||
self._table.add(
|
|
||||||
data,
|
|
||||||
mode=mode,
|
|
||||||
on_bad_vectors=on_bad_vectors,
|
|
||||||
fill_value=fill_value,
|
|
||||||
progress=progress,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
finally:
|
)
|
||||||
if owns:
|
|
||||||
progress.close()
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from functools import cached_property
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
@@ -278,7 +277,7 @@ def _sanitize_data(
|
|||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
new_metadata = target_schema.metadata or {}
|
new_metadata = target_schema.metadata or {}
|
||||||
new_metadata.update(metadata)
|
new_metadata = new_metadata.update(metadata)
|
||||||
target_schema = target_schema.with_metadata(new_metadata)
|
target_schema = target_schema.with_metadata(new_metadata)
|
||||||
|
|
||||||
_validate_schema(target_schema)
|
_validate_schema(target_schema)
|
||||||
@@ -557,21 +556,6 @@ def _table_uri(base: str, table_name: str) -> str:
|
|||||||
return join_uri(base, f"{table_name}.lance")
|
return join_uri(base, f"{table_name}.lance")
|
||||||
|
|
||||||
|
|
||||||
def _normalize_progress(progress):
|
|
||||||
"""Normalize a ``progress`` parameter for :meth:`Table.add`.
|
|
||||||
|
|
||||||
Returns ``(progress_obj, owns)`` where *owns* is True when we created a
|
|
||||||
tqdm bar that the caller must close.
|
|
||||||
"""
|
|
||||||
if progress is True:
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
return tqdm(unit=" rows"), True
|
|
||||||
if progress is False or progress is None:
|
|
||||||
return None, False
|
|
||||||
return progress, False
|
|
||||||
|
|
||||||
|
|
||||||
class Table(ABC):
|
class Table(ABC):
|
||||||
"""
|
"""
|
||||||
A Table is a collection of Records in a LanceDB Database.
|
A Table is a collection of Records in a LanceDB Database.
|
||||||
@@ -990,7 +974,6 @@ class Table(ABC):
|
|||||||
mode: AddMode = "append",
|
mode: AddMode = "append",
|
||||||
on_bad_vectors: OnBadVectorsType = "error",
|
on_bad_vectors: OnBadVectorsType = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
|
||||||
) -> AddResult:
|
) -> AddResult:
|
||||||
"""Add more data to the [Table](Table).
|
"""Add more data to the [Table](Table).
|
||||||
|
|
||||||
@@ -1012,29 +995,6 @@ class Table(ABC):
|
|||||||
One of "error", "drop", "fill".
|
One of "error", "drop", "fill".
|
||||||
fill_value: float, default 0.
|
fill_value: float, default 0.
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
progress: bool, callable, or tqdm-like, optional
|
|
||||||
Progress reporting during the add operation. Can be:
|
|
||||||
|
|
||||||
- ``True`` to automatically create and display a tqdm progress
|
|
||||||
bar (requires ``tqdm`` to be installed)::
|
|
||||||
|
|
||||||
table.add(data, progress=True)
|
|
||||||
|
|
||||||
- A **callable** that receives a dict with keys ``output_rows``,
|
|
||||||
``output_bytes``, ``total_rows``, ``elapsed_seconds``,
|
|
||||||
``active_tasks``, ``total_tasks``, and ``done``::
|
|
||||||
|
|
||||||
def on_progress(p):
|
|
||||||
print(f"{p['output_rows']}/{p['total_rows']} rows, "
|
|
||||||
f"{p['active_tasks']}/{p['total_tasks']} workers")
|
|
||||||
table.add(data, progress=on_progress)
|
|
||||||
|
|
||||||
- A **tqdm-compatible** progress bar whose ``total`` and
|
|
||||||
``update()`` will be called automatically. The postfix shows
|
|
||||||
write throughput (MB/s) and active worker count::
|
|
||||||
|
|
||||||
with tqdm() as pbar:
|
|
||||||
table.add(data, progress=pbar)
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -2532,7 +2492,6 @@ class LanceTable(Table):
|
|||||||
mode: AddMode = "append",
|
mode: AddMode = "append",
|
||||||
on_bad_vectors: OnBadVectorsType = "error",
|
on_bad_vectors: OnBadVectorsType = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
|
||||||
) -> AddResult:
|
) -> AddResult:
|
||||||
"""Add data to the table.
|
"""Add data to the table.
|
||||||
If vector columns are missing and the table
|
If vector columns are missing and the table
|
||||||
@@ -2551,29 +2510,17 @@ class LanceTable(Table):
|
|||||||
One of "error", "drop", "fill", "null".
|
One of "error", "drop", "fill", "null".
|
||||||
fill_value: float, default 0.
|
fill_value: float, default 0.
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
progress: bool, callable, or tqdm-like, optional
|
|
||||||
A callback or tqdm-compatible progress bar. See
|
|
||||||
:meth:`Table.add` for details.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
int
|
int
|
||||||
The number of vectors in the table.
|
The number of vectors in the table.
|
||||||
"""
|
"""
|
||||||
progress, owns = _normalize_progress(progress)
|
return LOOP.run(
|
||||||
try:
|
self._table.add(
|
||||||
return LOOP.run(
|
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||||
self._table.add(
|
|
||||||
data,
|
|
||||||
mode=mode,
|
|
||||||
on_bad_vectors=on_bad_vectors,
|
|
||||||
fill_value=fill_value,
|
|
||||||
progress=progress,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
finally:
|
)
|
||||||
if owns:
|
|
||||||
progress.close()
|
|
||||||
|
|
||||||
def merge(
|
def merge(
|
||||||
self,
|
self,
|
||||||
@@ -3822,7 +3769,6 @@ class AsyncTable:
|
|||||||
mode: Optional[Literal["append", "overwrite"]] = "append",
|
mode: Optional[Literal["append", "overwrite"]] = "append",
|
||||||
on_bad_vectors: Optional[OnBadVectorsType] = None,
|
on_bad_vectors: Optional[OnBadVectorsType] = None,
|
||||||
fill_value: Optional[float] = None,
|
fill_value: Optional[float] = None,
|
||||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
|
||||||
) -> AddResult:
|
) -> AddResult:
|
||||||
"""Add more data to the [Table](Table).
|
"""Add more data to the [Table](Table).
|
||||||
|
|
||||||
@@ -3844,9 +3790,6 @@ class AsyncTable:
|
|||||||
One of "error", "drop", "fill", "null".
|
One of "error", "drop", "fill", "null".
|
||||||
fill_value: float, default 0.
|
fill_value: float, default 0.
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
progress: callable or tqdm-like, optional
|
|
||||||
A callback or tqdm-compatible progress bar. See
|
|
||||||
:meth:`Table.add` for details.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
schema = await self.schema()
|
schema = await self.schema()
|
||||||
@@ -3857,13 +3800,7 @@ class AsyncTable:
|
|||||||
|
|
||||||
# _santitize_data is an old code path, but we will use it until the
|
# _santitize_data is an old code path, but we will use it until the
|
||||||
# new code path is ready.
|
# new code path is ready.
|
||||||
if mode == "overwrite":
|
if on_bad_vectors != "error" or (
|
||||||
# For overwrite, apply the same preprocessing as create_table
|
|
||||||
# so vector columns are inferred as FixedSizeList.
|
|
||||||
data, _ = sanitize_create_table(
|
|
||||||
data, None, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
|
||||||
)
|
|
||||||
elif on_bad_vectors != "error" or (
|
|
||||||
schema.metadata is not None and b"embedding_functions" in schema.metadata
|
schema.metadata is not None and b"embedding_functions" in schema.metadata
|
||||||
):
|
):
|
||||||
data = _sanitize_data(
|
data = _sanitize_data(
|
||||||
@@ -3876,9 +3813,8 @@ class AsyncTable:
|
|||||||
)
|
)
|
||||||
_register_optional_converters()
|
_register_optional_converters()
|
||||||
data = to_scannable(data)
|
data = to_scannable(data)
|
||||||
progress, owns = _normalize_progress(progress)
|
|
||||||
try:
|
try:
|
||||||
return await self._inner.add(data, mode or "append", progress=progress)
|
return await self._inner.add(data, mode or "append")
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "Cast error" in str(e):
|
if "Cast error" in str(e):
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
@@ -3886,9 +3822,6 @@ class AsyncTable:
|
|||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
if owns:
|
|
||||||
progress.close()
|
|
||||||
|
|
||||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
"""
|
"""
|
||||||
@@ -4211,7 +4144,7 @@ class AsyncTable:
|
|||||||
async_query = async_query.offset(query.offset)
|
async_query = async_query.offset(query.offset)
|
||||||
if query.columns:
|
if query.columns:
|
||||||
async_query = async_query.select(query.columns)
|
async_query = async_query.select(query.columns)
|
||||||
if query.filter is not None:
|
if query.filter:
|
||||||
async_query = async_query.where(query.filter)
|
async_query = async_query.where(query.filter)
|
||||||
if query.fast_search:
|
if query.fast_search:
|
||||||
async_query = async_query.fast_search()
|
async_query = async_query.fast_search()
|
||||||
@@ -4818,16 +4751,7 @@ class IndexStatistics:
|
|||||||
num_indexed_rows: int
|
num_indexed_rows: int
|
||||||
num_unindexed_rows: int
|
num_unindexed_rows: int
|
||||||
index_type: Literal[
|
index_type: Literal[
|
||||||
"IVF_FLAT",
|
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
|
||||||
"IVF_SQ",
|
|
||||||
"IVF_PQ",
|
|
||||||
"IVF_RQ",
|
|
||||||
"IVF_HNSW_SQ",
|
|
||||||
"IVF_HNSW_PQ",
|
|
||||||
"FTS",
|
|
||||||
"BTREE",
|
|
||||||
"BITMAP",
|
|
||||||
"LABEL_LIST",
|
|
||||||
]
|
]
|
||||||
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
||||||
num_indices: Optional[int] = None
|
num_indices: Optional[int] = None
|
||||||
|
|||||||
@@ -546,24 +546,3 @@ def test_openai_no_retry_on_401(mock_sleep):
|
|||||||
assert mock_func.call_count == 1
|
assert mock_func.call_count == 1
|
||||||
# Verify that sleep was never called (no retries)
|
# Verify that sleep was never called (no retries)
|
||||||
assert mock_sleep.call_count == 0
|
assert mock_sleep.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_url_retrieve_downloads_image():
|
|
||||||
"""
|
|
||||||
Embedding functions like open-clip, siglip, and jinaai use url_retrieve()
|
|
||||||
to download images from HTTP URLs. For example, open_clip._to_pil() calls:
|
|
||||||
|
|
||||||
PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
|
||||||
|
|
||||||
Verify that url_retrieve() can download an image and open it as PIL Image,
|
|
||||||
matching the real usage pattern in embedding functions.
|
|
||||||
"""
|
|
||||||
import io
|
|
||||||
|
|
||||||
Image = pytest.importorskip("PIL.Image")
|
|
||||||
from lancedb.embeddings.utils import url_retrieve
|
|
||||||
|
|
||||||
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"
|
|
||||||
image_bytes = url_retrieve(image_url)
|
|
||||||
img = Image.open(io.BytesIO(image_bytes))
|
|
||||||
assert img.size[0] > 0 and img.size[1] > 0
|
|
||||||
|
|||||||
@@ -177,60 +177,6 @@ async def test_analyze_plan(table: AsyncTable):
|
|||||||
assert "metrics=" in res
|
assert "metrics=" in res
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def table_with_id(tmpdir_factory) -> Table:
|
|
||||||
tmp_path = str(tmpdir_factory.mktemp("data"))
|
|
||||||
db = lancedb.connect(tmp_path)
|
|
||||||
data = pa.table(
|
|
||||||
{
|
|
||||||
"id": pa.array([1, 2, 3, 4], type=pa.int64()),
|
|
||||||
"text": pa.array(["a", "b", "cat", "dog"]),
|
|
||||||
"vector": pa.array(
|
|
||||||
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
|
|
||||||
type=pa.list_(pa.float32(), list_size=2),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
table = db.create_table("test_with_id", data)
|
|
||||||
table.create_fts_index("text", with_position=False, use_tantivy=False)
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def test_hybrid_prefilter_explain_plan(table_with_id: Table):
|
|
||||||
"""
|
|
||||||
Verify that the prefilter logic is not inverted in LanceHybridQueryBuilder.
|
|
||||||
"""
|
|
||||||
plan_prefilter = (
|
|
||||||
table_with_id.search(query_type="hybrid")
|
|
||||||
.vector([0.0, 0.0])
|
|
||||||
.text("dog")
|
|
||||||
.where("id = 1", prefilter=True)
|
|
||||||
.limit(2)
|
|
||||||
.explain_plan(verbose=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
plan_postfilter = (
|
|
||||||
table_with_id.search(query_type="hybrid")
|
|
||||||
.vector([0.0, 0.0])
|
|
||||||
.text("dog")
|
|
||||||
.where("id = 1", prefilter=False)
|
|
||||||
.limit(2)
|
|
||||||
.explain_plan(verbose=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# prefilter=True: filter is pushed into the LanceRead scan.
|
|
||||||
# The FTS sub-plan exposes this as "full_filter=id = Int64(1)" inside LanceRead.
|
|
||||||
assert "full_filter=id = Int64(1)" in plan_prefilter, (
|
|
||||||
f"Should push the filter into the scan.\nPlan:\n{plan_prefilter}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# prefilter=False: filter is applied as a separate FilterExec after the search.
|
|
||||||
# The filter must NOT be embedded in the scan.
|
|
||||||
assert "full_filter=id = Int64(1)" not in plan_postfilter, (
|
|
||||||
f"Should NOT push the filter into the scan.\nPlan:\n{plan_postfilter}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_scores():
|
def test_normalize_scores():
|
||||||
cases = [
|
cases = [
|
||||||
(pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),
|
(pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import random
|
import random
|
||||||
from typing import get_args, get_type_hints
|
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pytest
|
import pytest
|
||||||
@@ -23,7 +22,6 @@ from lancedb.index import (
|
|||||||
HnswSq,
|
HnswSq,
|
||||||
FTS,
|
FTS,
|
||||||
)
|
)
|
||||||
from lancedb.table import IndexStatistics
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@@ -285,23 +283,3 @@ async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
|
|||||||
for v in range(256):
|
for v in range(256):
|
||||||
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
|
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
|
||||||
assert res["id"][0].as_py() == v
|
assert res["id"][0].as_py() == v
|
||||||
|
|
||||||
|
|
||||||
def test_index_statistics_index_type_lists_all_supported_values():
|
|
||||||
expected_index_types = {
|
|
||||||
"IVF_FLAT",
|
|
||||||
"IVF_SQ",
|
|
||||||
"IVF_PQ",
|
|
||||||
"IVF_RQ",
|
|
||||||
"IVF_HNSW_SQ",
|
|
||||||
"IVF_HNSW_PQ",
|
|
||||||
"FTS",
|
|
||||||
"BTREE",
|
|
||||||
"BITMAP",
|
|
||||||
"LABEL_LIST",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert (
|
|
||||||
set(get_args(get_type_hints(IndexStatistics)["index_type"]))
|
|
||||||
== expected_index_types
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import shutil
|
|||||||
import pytest
|
import pytest
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import lancedb
|
import lancedb
|
||||||
from lance_namespace.errors import NamespaceNotEmptyError, TableNotFoundError
|
|
||||||
|
|
||||||
|
|
||||||
class TestNamespaceConnection:
|
class TestNamespaceConnection:
|
||||||
@@ -131,7 +130,7 @@ class TestNamespaceConnection:
|
|||||||
assert len(list(db.table_names(namespace=["test_ns"]))) == 0
|
assert len(list(db.table_names(namespace=["test_ns"]))) == 0
|
||||||
|
|
||||||
# Should not be able to open dropped table
|
# Should not be able to open dropped table
|
||||||
with pytest.raises(TableNotFoundError):
|
with pytest.raises(RuntimeError):
|
||||||
db.open_table("table1", namespace=["test_ns"])
|
db.open_table("table1", namespace=["test_ns"])
|
||||||
|
|
||||||
def test_create_table_with_schema(self):
|
def test_create_table_with_schema(self):
|
||||||
@@ -341,7 +340,7 @@ class TestNamespaceConnection:
|
|||||||
db.create_table("test_table", schema=schema, namespace=["test_namespace"])
|
db.create_table("test_table", schema=schema, namespace=["test_namespace"])
|
||||||
|
|
||||||
# Try to drop namespace with tables - should fail
|
# Try to drop namespace with tables - should fail
|
||||||
with pytest.raises(NamespaceNotEmptyError):
|
with pytest.raises(RuntimeError, match="is not empty"):
|
||||||
db.drop_namespace(["test_namespace"])
|
db.drop_namespace(["test_namespace"])
|
||||||
|
|
||||||
# Drop table first
|
# Drop table first
|
||||||
|
|||||||
@@ -147,12 +147,7 @@ class TrackingNamespace(LanceNamespace):
|
|||||||
This simulates a credential rotation system where each call returns
|
This simulates a credential rotation system where each call returns
|
||||||
new credentials that expire after credential_expires_in_seconds.
|
new credentials that expire after credential_expires_in_seconds.
|
||||||
"""
|
"""
|
||||||
# Start from base storage options (endpoint, region, allow_http, etc.)
|
modified = copy.deepcopy(storage_options) if storage_options else {}
|
||||||
# because DirectoryNamespace returns None for storage_options from
|
|
||||||
# describe_table/declare_table when no credential vendor is configured.
|
|
||||||
modified = copy.deepcopy(self.base_storage_options)
|
|
||||||
if storage_options:
|
|
||||||
modified.update(storage_options)
|
|
||||||
|
|
||||||
# Increment credentials to simulate rotation
|
# Increment credentials to simulate rotation
|
||||||
modified["aws_access_key_id"] = f"AKID_{count}"
|
modified["aws_access_key_id"] = f"AKID_{count}"
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from lancedb.query import (
|
|||||||
PhraseQuery,
|
PhraseQuery,
|
||||||
Query,
|
Query,
|
||||||
FullTextSearchQuery,
|
FullTextSearchQuery,
|
||||||
ensure_vector_query,
|
|
||||||
)
|
)
|
||||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||||
from lancedb.table import AsyncTable, LanceTable
|
from lancedb.table import AsyncTable, LanceTable
|
||||||
@@ -1502,18 +1501,6 @@ def test_search_empty_table(mem_db):
|
|||||||
assert results == []
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_vector_query_empty_list():
|
|
||||||
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
|
|
||||||
with pytest.raises(ValueError, match="non-empty"):
|
|
||||||
ensure_vector_query([])
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_vector_query_nested_empty_list():
|
|
||||||
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
|
|
||||||
with pytest.raises(ValueError, match="non-empty"):
|
|
||||||
ensure_vector_query([[]])
|
|
||||||
|
|
||||||
|
|
||||||
def test_fast_search(tmp_path):
|
def test_fast_search(tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
|
|
||||||
|
|||||||
@@ -1201,18 +1201,6 @@ async def test_header_provider_overrides_static_headers():
|
|||||||
await db.table_names()
|
await db.table_names()
|
||||||
|
|
||||||
|
|
||||||
def test_close():
|
|
||||||
"""Test that close() works without AttributeError."""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
def handler(req):
|
|
||||||
req.send_response(200)
|
|
||||||
req.end_headers()
|
|
||||||
|
|
||||||
with mock_lancedb_connection(handler) as db:
|
|
||||||
asyncio.run(db.close())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
|
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
|
||||||
def test_background_loop_cancellation(exception):
|
def test_background_loop_cancellation(exception):
|
||||||
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
|
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
|
||||||
|
|||||||
@@ -527,132 +527,6 @@ async def test_add_async(mem_db_async: AsyncConnection):
|
|||||||
assert await table.count_rows() == 3
|
assert await table.count_rows() == 3
|
||||||
|
|
||||||
|
|
||||||
def test_add_overwrite_infers_vector_schema(mem_db: DBConnection):
|
|
||||||
"""Overwrite should infer vector columns the same way create_table does.
|
|
||||||
|
|
||||||
Regression test for https://github.com/lancedb/lancedb/issues/3183
|
|
||||||
"""
|
|
||||||
table = mem_db.create_table(
|
|
||||||
"test_overwrite_vec",
|
|
||||||
data=[
|
|
||||||
{"vector": [1.0, 2.0, 3.0, 4.0], "item": "foo"},
|
|
||||||
{"vector": [5.0, 6.0, 7.0, 8.0], "item": "bar"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# create_table infers vector as fixed_size_list<float32, 4>
|
|
||||||
original_type = table.schema.field("vector").type
|
|
||||||
assert pa.types.is_fixed_size_list(original_type)
|
|
||||||
|
|
||||||
# overwrite with plain Python lists (PyArrow infers list<double>)
|
|
||||||
table.add(
|
|
||||||
[
|
|
||||||
{"vector": [10.0, 20.0, 30.0, 40.0], "item": "baz"},
|
|
||||||
],
|
|
||||||
mode="overwrite",
|
|
||||||
)
|
|
||||||
# overwrite should infer vector column the same way as create_table
|
|
||||||
new_type = table.schema.field("vector").type
|
|
||||||
assert pa.types.is_fixed_size_list(new_type), (
|
|
||||||
f"Expected fixed_size_list after overwrite, got {new_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_progress_callback(mem_db: DBConnection):
|
|
||||||
table = mem_db.create_table(
|
|
||||||
"test",
|
|
||||||
data=[{"id": 1}, {"id": 2}],
|
|
||||||
)
|
|
||||||
|
|
||||||
updates = []
|
|
||||||
table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
|
|
||||||
|
|
||||||
assert len(table) == 4
|
|
||||||
# The done callback always fires, so we should always get at least one.
|
|
||||||
assert len(updates) >= 1, "expected at least one progress callback"
|
|
||||||
for p in updates:
|
|
||||||
assert "output_rows" in p
|
|
||||||
assert "output_bytes" in p
|
|
||||||
assert "total_rows" in p
|
|
||||||
assert "elapsed_seconds" in p
|
|
||||||
assert "active_tasks" in p
|
|
||||||
assert "total_tasks" in p
|
|
||||||
assert "done" in p
|
|
||||||
# The last callback should have done=True.
|
|
||||||
assert updates[-1]["done"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_progress_tqdm_like(mem_db: DBConnection):
|
|
||||||
"""Test that a tqdm-like object gets total set and update() called."""
|
|
||||||
|
|
||||||
class FakeBar:
|
|
||||||
def __init__(self):
|
|
||||||
self.total = None
|
|
||||||
self.n = 0
|
|
||||||
self.postfix = None
|
|
||||||
|
|
||||||
def update(self, n):
|
|
||||||
self.n += n
|
|
||||||
|
|
||||||
def set_postfix_str(self, s):
|
|
||||||
self.postfix = s
|
|
||||||
|
|
||||||
def refresh(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
table = mem_db.create_table(
|
|
||||||
"test",
|
|
||||||
data=[{"id": 1}, {"id": 2}],
|
|
||||||
)
|
|
||||||
|
|
||||||
bar = FakeBar()
|
|
||||||
table.add([{"id": 3}, {"id": 4}], progress=bar)
|
|
||||||
|
|
||||||
assert len(table) == 4
|
|
||||||
# Postfix should contain throughput and worker count
|
|
||||||
if bar.postfix is not None:
|
|
||||||
assert "MB/s" in bar.postfix
|
|
||||||
assert "workers" in bar.postfix
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_progress_bool(mem_db: DBConnection):
|
|
||||||
"""Test that progress=True creates and closes a tqdm bar automatically."""
|
|
||||||
table = mem_db.create_table(
|
|
||||||
"test",
|
|
||||||
data=[{"id": 1}, {"id": 2}],
|
|
||||||
)
|
|
||||||
|
|
||||||
table.add([{"id": 3}, {"id": 4}], progress=True)
|
|
||||||
assert len(table) == 4
|
|
||||||
|
|
||||||
# progress=False should be the same as None
|
|
||||||
table.add([{"id": 5}], progress=False)
|
|
||||||
assert len(table) == 5
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_progress_callback_async(mem_db_async: AsyncConnection):
|
|
||||||
"""Progress callbacks work through the async path too."""
|
|
||||||
table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}])
|
|
||||||
|
|
||||||
updates = []
|
|
||||||
await table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
|
|
||||||
|
|
||||||
assert await table.count_rows() == 4
|
|
||||||
assert len(updates) >= 1
|
|
||||||
assert updates[-1]["done"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_progress_callback_error(mem_db: DBConnection):
|
|
||||||
"""A failing callback must not prevent the write from succeeding."""
|
|
||||||
table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}])
|
|
||||||
|
|
||||||
def bad_callback(p):
|
|
||||||
raise RuntimeError("boom")
|
|
||||||
|
|
||||||
table.add([{"id": 3}, {"id": 4}], progress=bad_callback)
|
|
||||||
assert len(table) == 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_polars(mem_db: DBConnection):
|
def test_polars(mem_db: DBConnection):
|
||||||
data = {
|
data = {
|
||||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
@@ -2173,33 +2047,3 @@ def test_table_uri(tmp_path):
|
|||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
table = db.create_table("my_table", data=[{"x": 0}])
|
table = db.create_table("my_table", data=[{"x": 0}])
|
||||||
assert table.uri == str(tmp_path / "my_table.lance")
|
assert table.uri == str(tmp_path / "my_table.lance")
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_data_metadata_not_stripped():
|
|
||||||
"""Regression test: dict.update() returns None, so assigning its result
|
|
||||||
would silently replace metadata with None, causing with_metadata(None)
|
|
||||||
to strip all schema metadata from the target schema."""
|
|
||||||
from lancedb.table import _sanitize_data
|
|
||||||
|
|
||||||
schema = pa.schema(
|
|
||||||
[pa.field("x", pa.int64())],
|
|
||||||
metadata={b"existing_key": b"existing_value"},
|
|
||||||
)
|
|
||||||
batch = pa.record_batch([pa.array([1, 2, 3])], schema=schema)
|
|
||||||
|
|
||||||
# Use a different field type so the reader and target schemas differ,
|
|
||||||
# forcing _cast_to_target_schema to rebuild the schema with the
|
|
||||||
# target's metadata (instead of taking the fast-path).
|
|
||||||
target_schema = pa.schema(
|
|
||||||
[pa.field("x", pa.int32())],
|
|
||||||
metadata={b"existing_key": b"existing_value"},
|
|
||||||
)
|
|
||||||
|
|
||||||
reader = pa.RecordBatchReader.from_batches(schema, [batch])
|
|
||||||
metadata = {b"new_key": b"new_value"}
|
|
||||||
result = _sanitize_data(reader, target_schema=target_schema, metadata=metadata)
|
|
||||||
|
|
||||||
result_schema = result.schema
|
|
||||||
assert result_schema.metadata is not None
|
|
||||||
assert result_schema.metadata[b"existing_key"] == b"existing_value"
|
|
||||||
assert result_schema.metadata[b"new_key"] == b"new_value"
|
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
||||||
|
|
||||||
//! PyO3 bindings for the LanceDB expression builder API.
|
|
||||||
//!
|
|
||||||
//! This module exposes [`PyExpr`] and helper free functions so Python can
|
|
||||||
//! build type-safe filter / projection expressions that map directly to
|
|
||||||
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
|
|
||||||
|
|
||||||
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
|
|
||||||
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
|
|
||||||
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
|
|
||||||
|
|
||||||
/// A type-safe DataFusion expression.
|
|
||||||
///
|
|
||||||
/// Instances are constructed via the free functions [`expr_col`] and
|
|
||||||
/// [`expr_lit`] and combined with the methods on this struct. On the Python
|
|
||||||
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
|
|
||||||
/// and adds Python operator overloads.
|
|
||||||
#[pyclass(name = "PyExpr")]
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct PyExpr(pub DfExpr);
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl PyExpr {
|
|
||||||
// ── comparisons ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
fn eq(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().eq(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ne(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().not_eq(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lt(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().lt(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lte(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().lt_eq(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn gt(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().gt(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn gte(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().gt_eq(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── logical ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
fn and_(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().and(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn or_(&self, other: &Self) -> Self {
|
|
||||||
Self(self.0.clone().or(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn not_(&self) -> Self {
|
|
||||||
use std::ops::Not;
|
|
||||||
Self(self.0.clone().not())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── arithmetic ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
fn add(&self, other: &Self) -> Self {
|
|
||||||
use std::ops::Add;
|
|
||||||
Self(self.0.clone().add(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sub(&self, other: &Self) -> Self {
|
|
||||||
use std::ops::Sub;
|
|
||||||
Self(self.0.clone().sub(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mul(&self, other: &Self) -> Self {
|
|
||||||
use std::ops::Mul;
|
|
||||||
Self(self.0.clone().mul(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn div(&self, other: &Self) -> Self {
|
|
||||||
use std::ops::Div;
|
|
||||||
Self(self.0.clone().div(other.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── string functions ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Convert string column to lowercase.
|
|
||||||
fn lower(&self) -> Self {
|
|
||||||
Self(lower(self.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert string column to uppercase.
|
|
||||||
fn upper(&self) -> Self {
|
|
||||||
Self(upper(self.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test whether the string contains `substr`.
|
|
||||||
fn contains(&self, substr: &Self) -> Self {
|
|
||||||
Self(contains(self.0.clone(), substr.0.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── type cast ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Cast the expression to `data_type`.
|
|
||||||
///
|
|
||||||
/// `data_type` must be a PyArrow `DataType` (e.g. `pa.int32()`).
|
|
||||||
/// On the Python side, `lancedb.expr.Expr.cast` also accepts type name
|
|
||||||
/// strings via `pa.lib.ensure_type` before forwarding here.
|
|
||||||
fn cast(&self, data_type: PyArrowType<DataType>) -> Self {
|
|
||||||
Self(expr_cast(self.0.clone(), data_type.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── utilities ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Render the expression as a SQL string (useful for debugging).
|
|
||||||
fn to_sql(&self) -> PyResult<String> {
|
|
||||||
lancedb::expr::expr_to_sql_string(&self.0).map_err(|e| PyValueError::new_err(e.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __repr__(&self) -> PyResult<String> {
|
|
||||||
let sql =
|
|
||||||
lancedb::expr::expr_to_sql_string(&self.0).unwrap_or_else(|_| "<expr>".to_string());
|
|
||||||
Ok(format!("PyExpr({})", sql))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── free functions ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Create a column reference expression.
|
|
||||||
///
|
|
||||||
/// The column name is preserved exactly as given (case-sensitive), so
|
|
||||||
/// `col("firstName")` correctly references a field named `firstName`.
|
|
||||||
#[pyfunction]
|
|
||||||
pub fn expr_col(name: &str) -> PyExpr {
|
|
||||||
PyExpr(ldb_col(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a literal value expression.
|
|
||||||
///
|
|
||||||
/// Supported Python types: `bool`, `int`, `float`, `str`.
|
|
||||||
#[pyfunction]
|
|
||||||
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
|
|
||||||
// bool must be checked before int because bool is a subclass of int in Python
|
|
||||||
if let Ok(b) = value.extract::<bool>() {
|
|
||||||
return Ok(PyExpr(df_lit(b)));
|
|
||||||
}
|
|
||||||
if let Ok(i) = value.extract::<i64>() {
|
|
||||||
return Ok(PyExpr(df_lit(i)));
|
|
||||||
}
|
|
||||||
if let Ok(f) = value.extract::<f64>() {
|
|
||||||
return Ok(PyExpr(df_lit(f)));
|
|
||||||
}
|
|
||||||
if let Ok(s) = value.extract::<String>() {
|
|
||||||
return Ok(PyExpr(df_lit(s)));
|
|
||||||
}
|
|
||||||
Err(PyValueError::new_err(format!(
|
|
||||||
"unsupported literal type: {}. Supported: bool, int, float, str",
|
|
||||||
value.get_type().name()?
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Call an arbitrary registered SQL function by name.
|
|
||||||
///
|
|
||||||
/// See `lancedb::expr::func` for the list of supported function names.
|
|
||||||
#[pyfunction]
|
|
||||||
pub fn expr_func(name: &str, args: Vec<PyExpr>) -> PyResult<PyExpr> {
|
|
||||||
let df_args: Vec<DfExpr> = args.into_iter().map(|e| e.0).collect();
|
|
||||||
lancedb::expr::func(name, df_args)
|
|
||||||
.map(PyExpr)
|
|
||||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
use arrow::RecordBatchStream;
|
use arrow::RecordBatchStream;
|
||||||
use connection::{Connection, connect};
|
use connection::{Connection, connect};
|
||||||
use env_logger::Env;
|
use env_logger::Env;
|
||||||
use expr::{PyExpr, expr_col, expr_func, expr_lit};
|
|
||||||
use index::IndexConfig;
|
use index::IndexConfig;
|
||||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
@@ -22,7 +21,6 @@ use table::{
|
|||||||
pub mod arrow;
|
pub mod arrow;
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod expr;
|
|
||||||
pub mod header;
|
pub mod header;
|
||||||
pub mod index;
|
pub mod index;
|
||||||
pub mod namespace;
|
pub mod namespace;
|
||||||
@@ -57,14 +55,10 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|||||||
m.add_class::<UpdateResult>()?;
|
m.add_class::<UpdateResult>()?;
|
||||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||||
m.add_class::<PyPermutationReader>()?;
|
m.add_class::<PyPermutationReader>()?;
|
||||||
m.add_class::<PyExpr>()?;
|
|
||||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(expr_col, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(expr_lit, m)?)?;
|
|
||||||
m.add_function(wrap_pyfunction!(expr_func, m)?)?;
|
|
||||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,10 +35,12 @@ use pyo3::types::PyList;
|
|||||||
use pyo3::types::{PyDict, PyString};
|
use pyo3::types::{PyDict, PyString};
|
||||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||||
use pyo3::{PyErr, pyclass};
|
use pyo3::{PyErr, pyclass};
|
||||||
use pyo3::{exceptions::PyValueError, intern};
|
use pyo3::{
|
||||||
|
exceptions::{PyNotImplementedError, PyValueError},
|
||||||
|
intern,
|
||||||
|
};
|
||||||
use pyo3_async_runtimes::tokio::future_into_py;
|
use pyo3_async_runtimes::tokio::future_into_py;
|
||||||
|
|
||||||
use crate::expr::PyExpr;
|
|
||||||
use crate::util::parse_distance_type;
|
use crate::util::parse_distance_type;
|
||||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||||
use crate::{error::PythonErrorExt, index::class_name};
|
use crate::{error::PythonErrorExt, index::class_name};
|
||||||
@@ -342,13 +344,9 @@ impl<'py> IntoPyObject<'py> for PyQueryFilter {
|
|||||||
|
|
||||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||||
match self.0 {
|
match self.0 {
|
||||||
QueryFilter::Datafusion(expr) => {
|
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
|
||||||
// Serialize the DataFusion expression to a SQL string so that
|
"Datafusion filter has no conversion to Python",
|
||||||
// callers (e.g. remote tables) see the same format as Sql.
|
)),
|
||||||
let sql = lancedb::expr::expr_to_sql_string(&expr)
|
|
||||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
|
||||||
Ok(sql.into_pyobject(py)?.into_any())
|
|
||||||
}
|
|
||||||
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
|
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
|
||||||
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
|
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
|
||||||
}
|
}
|
||||||
@@ -372,20 +370,10 @@ impl Query {
|
|||||||
self.inner = self.inner.clone().only_if(predicate);
|
self.inner = self.inner.clone().only_if(predicate);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
|
||||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
|
||||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
|
||||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
|
||||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||||
}
|
}
|
||||||
@@ -619,20 +607,10 @@ impl FTSQuery {
|
|||||||
self.inner = self.inner.clone().only_if(predicate);
|
self.inner = self.inner.clone().only_if(predicate);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
|
||||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
|
||||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
|
||||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
|
||||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||||
}
|
}
|
||||||
@@ -747,10 +725,6 @@ impl VectorQuery {
|
|||||||
self.inner = self.inner.clone().only_if(predicate);
|
self.inner = self.inner.clone().only_if(predicate);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
|
||||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
|
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
|
||||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||||
let array = make_array(data);
|
let array = make_array(data);
|
||||||
@@ -762,12 +736,6 @@ impl VectorQuery {
|
|||||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
|
||||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
|
||||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
|
||||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||||
}
|
}
|
||||||
@@ -922,21 +890,11 @@ impl HybridQuery {
|
|||||||
self.inner_fts.r#where(predicate);
|
self.inner_fts.r#where(predicate);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
|
||||||
self.inner_vec.where_expr(expr.clone());
|
|
||||||
self.inner_fts.where_expr(expr);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||||
self.inner_vec.select(columns.clone());
|
self.inner_vec.select(columns.clone());
|
||||||
self.inner_fts.select(columns);
|
self.inner_fts.select(columns);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
|
||||||
self.inner_vec.select_expr(columns.clone());
|
|
||||||
self.inner_fts.select_expr(columns);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||||
self.inner_vec.select_columns(columns.clone());
|
self.inner_vec.select_columns(columns.clone());
|
||||||
self.inner_fts.select_columns(columns);
|
self.inner_fts.select_columns(columns);
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use lancedb::table::{
|
|||||||
Table as LanceDbTable,
|
Table as LanceDbTable,
|
||||||
};
|
};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||||
pyclass, pymethods,
|
pyclass, pymethods,
|
||||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||||
@@ -299,12 +299,10 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyo3(signature = (data, mode, progress=None))]
|
|
||||||
pub fn add<'a>(
|
pub fn add<'a>(
|
||||||
self_: PyRef<'a, Self>,
|
self_: PyRef<'a, Self>,
|
||||||
data: PyScannable,
|
data: PyScannable,
|
||||||
mode: String,
|
mode: String,
|
||||||
progress: Option<Py<PyAny>>,
|
|
||||||
) -> PyResult<Bound<'a, PyAny>> {
|
) -> PyResult<Bound<'a, PyAny>> {
|
||||||
let mut op = self_.inner_ref()?.add(data);
|
let mut op = self_.inner_ref()?.add(data);
|
||||||
if mode == "append" {
|
if mode == "append" {
|
||||||
@@ -314,81 +312,6 @@ impl Table {
|
|||||||
} else {
|
} else {
|
||||||
return Err(PyValueError::new_err(format!("Invalid mode: {}", mode)));
|
return Err(PyValueError::new_err(format!("Invalid mode: {}", mode)));
|
||||||
}
|
}
|
||||||
if let Some(progress_obj) = progress {
|
|
||||||
let is_callable = Python::attach(|py| progress_obj.bind(py).is_callable());
|
|
||||||
if is_callable {
|
|
||||||
// Callback: call with a dict of progress info.
|
|
||||||
op = op.progress(move |p| {
|
|
||||||
Python::attach(|py| {
|
|
||||||
let dict = PyDict::new(py);
|
|
||||||
if let Err(e) = dict
|
|
||||||
.set_item("output_rows", p.output_rows())
|
|
||||||
.and_then(|_| dict.set_item("output_bytes", p.output_bytes()))
|
|
||||||
.and_then(|_| dict.set_item("total_rows", p.total_rows()))
|
|
||||||
.and_then(|_| {
|
|
||||||
dict.set_item("elapsed_seconds", p.elapsed().as_secs_f64())
|
|
||||||
})
|
|
||||||
.and_then(|_| dict.set_item("active_tasks", p.active_tasks()))
|
|
||||||
.and_then(|_| dict.set_item("total_tasks", p.total_tasks()))
|
|
||||||
.and_then(|_| dict.set_item("done", p.done()))
|
|
||||||
{
|
|
||||||
log::warn!("progress dict error: {e}");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if let Err(e) = progress_obj.call1(py, (dict,)) {
|
|
||||||
log::warn!("progress callback error: {e}");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// tqdm-like: has update() method.
|
|
||||||
let mut last_rows: usize = 0;
|
|
||||||
let mut total_set = false;
|
|
||||||
op = op.progress(move |p| {
|
|
||||||
let current = p.output_rows();
|
|
||||||
let prev = last_rows;
|
|
||||||
last_rows = current;
|
|
||||||
Python::attach(|py| {
|
|
||||||
if let Some(total) = p.total_rows()
|
|
||||||
&& !total_set
|
|
||||||
{
|
|
||||||
if let Err(e) = progress_obj.setattr(py, "total", total) {
|
|
||||||
log::warn!("progress setattr error: {e}");
|
|
||||||
}
|
|
||||||
total_set = true;
|
|
||||||
}
|
|
||||||
let delta = current.saturating_sub(prev);
|
|
||||||
if delta > 0 {
|
|
||||||
if let Err(e) = progress_obj.call_method1(py, "update", (delta,)) {
|
|
||||||
log::warn!("progress update error: {e}");
|
|
||||||
}
|
|
||||||
// Show throughput and active workers in tqdm postfix.
|
|
||||||
let elapsed = p.elapsed().as_secs_f64();
|
|
||||||
if elapsed > 0.0 {
|
|
||||||
let mb_per_sec = p.output_bytes() as f64 / elapsed / 1_000_000.0;
|
|
||||||
let postfix = format!(
|
|
||||||
"{:.1} MB/s | {}/{} workers",
|
|
||||||
mb_per_sec,
|
|
||||||
p.active_tasks(),
|
|
||||||
p.total_tasks()
|
|
||||||
);
|
|
||||||
if let Err(e) =
|
|
||||||
progress_obj.call_method1(py, "set_postfix_str", (postfix,))
|
|
||||||
{
|
|
||||||
log::warn!("progress set_postfix_str error: {e}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.done() {
|
|
||||||
// Force a final refresh so the bar shows completion.
|
|
||||||
if let Err(e) = progress_obj.call_method0(py, "refresh") {
|
|
||||||
log::warn!("progress refresh error: {e}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
future_into_py(self_.py(), async move {
|
future_into_py(self_.py(), async move {
|
||||||
let result = op.execute().await.infer_error()?;
|
let result = op.execute().await.infer_error()?;
|
||||||
|
|||||||
@@ -1,387 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
||||||
|
|
||||||
"""Tests for the type-safe expression builder API."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pyarrow as pa
|
|
||||||
import lancedb
|
|
||||||
from lancedb.expr import Expr, col, lit, func
|
|
||||||
|
|
||||||
|
|
||||||
# ── unit tests for Expr construction ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprConstruction:
|
|
||||||
def test_col_returns_expr(self):
|
|
||||||
e = col("age")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
|
|
||||||
def test_lit_int(self):
|
|
||||||
e = lit(42)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
|
|
||||||
def test_lit_float(self):
|
|
||||||
e = lit(3.14)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
|
|
||||||
def test_lit_str(self):
|
|
||||||
e = lit("hello")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
|
|
||||||
def test_lit_bool(self):
|
|
||||||
e = lit(True)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
|
|
||||||
def test_lit_unsupported_type_raises(self):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
lit([1, 2, 3])
|
|
||||||
|
|
||||||
def test_func(self):
|
|
||||||
e = func("lower", col("name"))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "lower(name)"
|
|
||||||
|
|
||||||
def test_func_unknown_raises(self):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
func("not_a_real_function", col("x"))
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprOperators:
|
|
||||||
def test_eq_operator(self):
|
|
||||||
e = col("x") == lit(1)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x = 1)"
|
|
||||||
|
|
||||||
def test_ne_operator(self):
|
|
||||||
e = col("x") != lit(1)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x <> 1)"
|
|
||||||
|
|
||||||
def test_lt_operator(self):
|
|
||||||
e = col("age") < lit(18)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(age < 18)"
|
|
||||||
|
|
||||||
def test_le_operator(self):
|
|
||||||
e = col("age") <= lit(18)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(age <= 18)"
|
|
||||||
|
|
||||||
def test_gt_operator(self):
|
|
||||||
e = col("age") > lit(18)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(age > 18)"
|
|
||||||
|
|
||||||
def test_ge_operator(self):
|
|
||||||
e = col("age") >= lit(18)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(age >= 18)"
|
|
||||||
|
|
||||||
def test_and_operator(self):
|
|
||||||
e = (col("age") > lit(18)) & (col("status") == lit("active"))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "((age > 18) AND (status = 'active'))"
|
|
||||||
|
|
||||||
def test_or_operator(self):
|
|
||||||
e = (col("a") == lit(1)) | (col("b") == lit(2))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "((a = 1) OR (b = 2))"
|
|
||||||
|
|
||||||
def test_invert_operator(self):
|
|
||||||
e = ~(col("active") == lit(True))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "NOT (active = true)"
|
|
||||||
|
|
||||||
def test_add_operator(self):
|
|
||||||
e = col("x") + lit(1)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x + 1)"
|
|
||||||
|
|
||||||
def test_sub_operator(self):
|
|
||||||
e = col("x") - lit(1)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x - 1)"
|
|
||||||
|
|
||||||
def test_mul_operator(self):
|
|
||||||
e = col("price") * lit(1.1)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(price * 1.1)"
|
|
||||||
|
|
||||||
def test_div_operator(self):
|
|
||||||
e = col("total") / lit(2)
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(total / 2)"
|
|
||||||
|
|
||||||
def test_radd(self):
|
|
||||||
e = lit(1) + col("x")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(1 + x)"
|
|
||||||
|
|
||||||
def test_rmul(self):
|
|
||||||
e = lit(2) * col("x")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(2 * x)"
|
|
||||||
|
|
||||||
def test_coerce_plain_int(self):
|
|
||||||
# Operators should auto-wrap plain Python values via lit()
|
|
||||||
e = col("age") > 18
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(age > 18)"
|
|
||||||
|
|
||||||
def test_coerce_plain_str(self):
|
|
||||||
e = col("name") == "alice"
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(name = 'alice')"
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprStringMethods:
|
|
||||||
def test_lower(self):
|
|
||||||
e = col("name").lower()
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "lower(name)"
|
|
||||||
|
|
||||||
def test_upper(self):
|
|
||||||
e = col("name").upper()
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "upper(name)"
|
|
||||||
|
|
||||||
def test_contains(self):
|
|
||||||
e = col("text").contains(lit("hello"))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "contains(text, 'hello')"
|
|
||||||
|
|
||||||
def test_contains_with_str_coerce(self):
|
|
||||||
e = col("text").contains("hello")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "contains(text, 'hello')"
|
|
||||||
|
|
||||||
def test_chained_lower_eq(self):
|
|
||||||
e = col("name").lower() == lit("alice")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(lower(name) = 'alice')"
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprCast:
|
|
||||||
def test_cast_string(self):
|
|
||||||
e = col("id").cast("string")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(id AS VARCHAR)"
|
|
||||||
|
|
||||||
def test_cast_int32(self):
|
|
||||||
e = col("score").cast("int32")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(score AS INTEGER)"
|
|
||||||
|
|
||||||
def test_cast_float64(self):
|
|
||||||
e = col("val").cast("float64")
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(val AS DOUBLE)"
|
|
||||||
|
|
||||||
def test_cast_pyarrow_type(self):
|
|
||||||
e = col("score").cast(pa.int32())
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(score AS INTEGER)"
|
|
||||||
|
|
||||||
def test_cast_pyarrow_float64(self):
|
|
||||||
e = col("val").cast(pa.float64())
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(val AS DOUBLE)"
|
|
||||||
|
|
||||||
def test_cast_pyarrow_string(self):
|
|
||||||
e = col("id").cast(pa.string())
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "CAST(id AS VARCHAR)"
|
|
||||||
|
|
||||||
def test_cast_pyarrow_and_string_equivalent(self):
|
|
||||||
# pa.int32() and "int32" should produce equivalent SQL
|
|
||||||
sql_str = col("x").cast("int32").to_sql()
|
|
||||||
sql_pa = col("x").cast(pa.int32()).to_sql()
|
|
||||||
assert sql_str == sql_pa
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprNamedMethods:
|
|
||||||
def test_eq_method(self):
|
|
||||||
e = col("x").eq(lit(1))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x = 1)"
|
|
||||||
|
|
||||||
def test_gt_method(self):
|
|
||||||
e = col("x").gt(lit(0))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "(x > 0)"
|
|
||||||
|
|
||||||
def test_and_method(self):
|
|
||||||
e = col("x").gt(lit(0)).and_(col("y").lt(lit(10)))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "((x > 0) AND (y < 10))"
|
|
||||||
|
|
||||||
def test_or_method(self):
|
|
||||||
e = col("x").eq(lit(1)).or_(col("x").eq(lit(2)))
|
|
||||||
assert isinstance(e, Expr)
|
|
||||||
assert e.to_sql() == "((x = 1) OR (x = 2))"
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprRepr:
|
|
||||||
def test_repr(self):
|
|
||||||
e = col("age") > lit(18)
|
|
||||||
assert repr(e) == "Expr((age > 18))"
|
|
||||||
|
|
||||||
def test_to_sql(self):
|
|
||||||
e = col("age") > 18
|
|
||||||
assert e.to_sql() == "(age > 18)"
|
|
||||||
|
|
||||||
def test_unhashable(self):
|
|
||||||
e = col("x")
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
{e: 1}
|
|
||||||
|
|
||||||
|
|
||||||
# ── integration tests: end-to-end query against a real table ─────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def simple_table(tmp_path):
|
|
||||||
db = lancedb.connect(str(tmp_path))
|
|
||||||
data = pa.table(
|
|
||||||
{
|
|
||||||
"id": [1, 2, 3, 4, 5],
|
|
||||||
"name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
|
|
||||||
"age": [25, 17, 30, 22, 15],
|
|
||||||
"score": [1.5, 2.0, 3.5, 4.0, 0.5],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return db.create_table("test", data)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprFilter:
|
|
||||||
def test_simple_gt_filter(self, simple_table):
|
|
||||||
result = simple_table.search().where(col("age") > lit(20)).to_arrow()
|
|
||||||
assert result.num_rows == 3 # ages 25, 30, 22
|
|
||||||
|
|
||||||
def test_compound_and_filter(self, simple_table):
|
|
||||||
result = (
|
|
||||||
simple_table.search()
|
|
||||||
.where((col("age") > lit(18)) & (col("score") > lit(2.0)))
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert result.num_rows == 2 # (30, 3.5) and (22, 4.0)
|
|
||||||
|
|
||||||
def test_string_equality_filter(self, simple_table):
|
|
||||||
result = simple_table.search().where(col("name") == lit("Bob")).to_arrow()
|
|
||||||
assert result.num_rows == 1
|
|
||||||
|
|
||||||
def test_or_filter(self, simple_table):
|
|
||||||
result = (
|
|
||||||
simple_table.search()
|
|
||||||
.where((col("age") < lit(18)) | (col("age") > lit(28)))
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert result.num_rows == 3 # ages 17, 30, 15
|
|
||||||
|
|
||||||
def test_coercion_no_lit(self, simple_table):
|
|
||||||
# Python values should be auto-coerced
|
|
||||||
result = simple_table.search().where(col("age") > 20).to_arrow()
|
|
||||||
assert result.num_rows == 3
|
|
||||||
|
|
||||||
def test_string_sql_still_works(self, simple_table):
|
|
||||||
# Backwards compatibility: plain strings still accepted
|
|
||||||
result = simple_table.search().where("age > 20").to_arrow()
|
|
||||||
assert result.num_rows == 3
|
|
||||||
|
|
||||||
|
|
||||||
class TestExprProjection:
|
|
||||||
def test_select_with_expr(self, simple_table):
|
|
||||||
result = (
|
|
||||||
simple_table.search()
|
|
||||||
.select({"double_score": col("score") * lit(2)})
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert "double_score" in result.schema.names
|
|
||||||
|
|
||||||
def test_select_mixed_str_and_expr(self, simple_table):
|
|
||||||
result = (
|
|
||||||
simple_table.search()
|
|
||||||
.select({"id": "id", "double_score": col("score") * lit(2)})
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert "id" in result.schema.names
|
|
||||||
assert "double_score" in result.schema.names
|
|
||||||
|
|
||||||
def test_select_list_of_columns(self, simple_table):
|
|
||||||
# Plain list of str still works
|
|
||||||
result = simple_table.search().select(["id", "name"]).to_arrow()
|
|
||||||
assert result.schema.names == ["id", "name"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── column name edge cases ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestColNaming:
|
|
||||||
"""Unit tests verifying that col() preserves identifiers exactly.
|
|
||||||
|
|
||||||
Identifiers that need quoting (camelCase, spaces, leading digits, unicode)
|
|
||||||
are wrapped in backticks to match the lance SQL parser's dialect.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_camel_case_preserved_in_sql(self):
|
|
||||||
# camelCase is quoted with backticks so the case round-trips correctly.
|
|
||||||
assert col("firstName").to_sql() == "`firstName`"
|
|
||||||
|
|
||||||
def test_camel_case_in_expression(self):
|
|
||||||
assert (col("firstName") > lit(18)).to_sql() == "(`firstName` > 18)"
|
|
||||||
|
|
||||||
def test_space_in_name_quoted(self):
|
|
||||||
assert col("first name").to_sql() == "`first name`"
|
|
||||||
|
|
||||||
def test_space_in_expression(self):
|
|
||||||
assert (col("first name") == lit("A")).to_sql() == "(`first name` = 'A')"
|
|
||||||
|
|
||||||
def test_leading_digit_quoted(self):
|
|
||||||
assert col("2fast").to_sql() == "`2fast`"
|
|
||||||
|
|
||||||
def test_unicode_quoted(self):
|
|
||||||
assert col("名前").to_sql() == "`名前`"
|
|
||||||
|
|
||||||
def test_snake_case_unquoted(self):
|
|
||||||
# Plain snake_case needs no quoting.
|
|
||||||
assert col("first_name").to_sql() == "first_name"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def special_col_table(tmp_path):
|
|
||||||
db = lancedb.connect(str(tmp_path))
|
|
||||||
data = pa.table(
|
|
||||||
{
|
|
||||||
"firstName": ["Alice", "Bob", "Charlie"],
|
|
||||||
"first name": ["A", "B", "C"],
|
|
||||||
"score": [10, 20, 30],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return db.create_table("special", data)
|
|
||||||
|
|
||||||
|
|
||||||
class TestColNamingIntegration:
|
|
||||||
def test_camel_case_filter(self, special_col_table):
|
|
||||||
result = (
|
|
||||||
special_col_table.search()
|
|
||||||
.where(col("firstName") == lit("Alice"))
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert result.num_rows == 1
|
|
||||||
assert result["firstName"][0].as_py() == "Alice"
|
|
||||||
|
|
||||||
def test_space_in_col_filter(self, special_col_table):
|
|
||||||
result = (
|
|
||||||
special_col_table.search().where(col("first name") == lit("B")).to_arrow()
|
|
||||||
)
|
|
||||||
assert result.num_rows == 1
|
|
||||||
|
|
||||||
def test_camel_case_projection(self, special_col_table):
|
|
||||||
result = (
|
|
||||||
special_col_table.search()
|
|
||||||
.select({"upper_name": col("firstName").upper()})
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert "upper_name" in result.schema.names
|
|
||||||
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.27.2-beta.1"
|
version = "0.27.0-beta.5"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|||||||
@@ -596,8 +596,11 @@ pub struct ConnectBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "remote")]
|
#[cfg(feature = "remote")]
|
||||||
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
|
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 3] = [
|
||||||
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
|
("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name"),
|
||||||
|
("AZURE_CLIENT_ID", "azure_client_id"),
|
||||||
|
("AZURE_TENANT_ID", "azure_tenant_id"),
|
||||||
|
];
|
||||||
|
|
||||||
impl ConnectBuilder {
|
impl ConnectBuilder {
|
||||||
/// Create a new [`ConnectOptions`] with the given database URI.
|
/// Create a new [`ConnectOptions`] with the given database URI.
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ impl Shuffler {
|
|||||||
.await?;
|
.await?;
|
||||||
// Need to read the entire file in a single batch for in-memory shuffling
|
// Need to read the entire file in a single batch for in-memory shuffling
|
||||||
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
|
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
|
||||||
let mut rng = rng.lock().unwrap_or_else(|e| e.into_inner());
|
let mut rng = rng.lock().unwrap();
|
||||||
Self::shuffle_batch(&batch, &mut rng, clump_size)
|
Self::shuffle_batch(&batch, &mut rng, clump_size)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,17 +27,7 @@ use arrow_schema::DataType;
|
|||||||
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
|
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
|
||||||
use datafusion_functions::string::expr_fn as string_expr_fn;
|
use datafusion_functions::string::expr_fn as string_expr_fn;
|
||||||
|
|
||||||
pub use datafusion_expr::lit;
|
pub use datafusion_expr::{col, lit};
|
||||||
|
|
||||||
/// Create a column reference expression, preserving the name exactly as given.
|
|
||||||
///
|
|
||||||
/// Unlike DataFusion's built-in [`col`][datafusion_expr::col], this function
|
|
||||||
/// does **not** normalise the identifier to lower-case, so
|
|
||||||
/// `col("firstName")` correctly references a field named `firstName`.
|
|
||||||
pub fn col(name: impl Into<String>) -> DfExpr {
|
|
||||||
use datafusion_common::Column;
|
|
||||||
DfExpr::Column(Column::new_unqualified(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub use datafusion_expr::Expr as DfExpr;
|
pub use datafusion_expr::Expr as DfExpr;
|
||||||
|
|
||||||
|
|||||||
@@ -2,37 +2,11 @@
|
|||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
use datafusion_expr::Expr;
|
use datafusion_expr::Expr;
|
||||||
use datafusion_sql::unparser::{self, dialect::Dialect};
|
use datafusion_sql::unparser;
|
||||||
|
|
||||||
/// Unparser dialect that matches the quoting style expected by the Lance SQL
|
|
||||||
/// parser. Lance uses backtick (`` ` ``) as the only delimited-identifier
|
|
||||||
/// quote character, so we must produce `` `firstName` `` rather than
|
|
||||||
/// `"firstName"` for identifiers that require quoting.
|
|
||||||
///
|
|
||||||
/// We quote an identifier when it:
|
|
||||||
/// * is a SQL reserved word, OR
|
|
||||||
/// * contains characters outside `[a-zA-Z0-9_]`, OR
|
|
||||||
/// * starts with a digit, OR
|
|
||||||
/// * contains upper-case letters (unquoted identifiers are normalised to
|
|
||||||
/// lower-case by the SQL parser, which would break case-sensitive schemas).
|
|
||||||
struct LanceSqlDialect;
|
|
||||||
|
|
||||||
impl Dialect for LanceSqlDialect {
|
|
||||||
fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
|
|
||||||
let needs_quote = identifier.chars().any(|c| c.is_ascii_uppercase())
|
|
||||||
|| !identifier
|
|
||||||
.chars()
|
|
||||||
.enumerate()
|
|
||||||
.all(|(i, c)| c == '_' || c.is_ascii_alphabetic() || (i > 0 && c.is_ascii_digit()));
|
|
||||||
if needs_quote { Some('`') } else { None }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
||||||
let ast = unparser::Unparser::new(&LanceSqlDialect)
|
let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput {
|
||||||
.expr_to_sql(expr)
|
message: format!("failed to serialize expression to SQL: {}", e),
|
||||||
.map_err(|e| crate::Error::InvalidInput {
|
})?;
|
||||||
message: format!("failed to serialize expression to SQL: {}", e),
|
|
||||||
})?;
|
|
||||||
Ok(ast.to_string())
|
Ok(ast.to_string())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,13 +66,13 @@ impl IoTrackingStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn record_read(&self, num_bytes: u64) {
|
fn record_read(&self, num_bytes: u64) {
|
||||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
let mut stats = self.stats.lock().unwrap();
|
||||||
stats.read_iops += 1;
|
stats.read_iops += 1;
|
||||||
stats.read_bytes += num_bytes;
|
stats.read_bytes += num_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_write(&self, num_bytes: u64) {
|
fn record_write(&self, num_bytes: u64) {
|
||||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
let mut stats = self.stats.lock().unwrap();
|
||||||
stats.write_iops += 1;
|
stats.write_iops += 1;
|
||||||
stats.write_bytes += num_bytes;
|
stats.write_bytes += num_bytes;
|
||||||
}
|
}
|
||||||
@@ -229,63 +229,10 @@ impl MultipartUpload for IoTrackingMultipartUpload {
|
|||||||
|
|
||||||
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
|
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
|
||||||
{
|
{
|
||||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
let mut stats = self.stats.lock().unwrap();
|
||||||
stats.write_iops += 1;
|
stats.write_iops += 1;
|
||||||
stats.write_bytes += payload.content_length() as u64;
|
stats.write_bytes += payload.content_length() as u64;
|
||||||
}
|
}
|
||||||
self.target.put_part(payload)
|
self.target.put_part(payload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// Helper: poison a Mutex<IoStats> by panicking while holding the lock.
|
|
||||||
fn poison_stats(stats: &Arc<Mutex<IoStats>>) {
|
|
||||||
let stats_clone = stats.clone();
|
|
||||||
let handle = std::thread::spawn(move || {
|
|
||||||
let _guard = stats_clone.lock().unwrap();
|
|
||||||
panic!("intentional panic to poison stats mutex");
|
|
||||||
});
|
|
||||||
let _ = handle.join();
|
|
||||||
assert!(stats.lock().is_err(), "mutex should be poisoned");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_record_read_recovers_from_poisoned_lock() {
|
|
||||||
let stats = Arc::new(Mutex::new(IoStats::default()));
|
|
||||||
let store = IoTrackingStore {
|
|
||||||
target: Arc::new(object_store::memory::InMemory::new()),
|
|
||||||
stats: stats.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
poison_stats(&stats);
|
|
||||||
|
|
||||||
// record_read should not panic
|
|
||||||
store.record_read(1024);
|
|
||||||
|
|
||||||
// Verify the stats were updated despite poisoning
|
|
||||||
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
|
|
||||||
assert_eq!(s.read_iops, 1);
|
|
||||||
assert_eq!(s.read_bytes, 1024);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_record_write_recovers_from_poisoned_lock() {
|
|
||||||
let stats = Arc::new(Mutex::new(IoStats::default()));
|
|
||||||
let store = IoTrackingStore {
|
|
||||||
target: Arc::new(object_store::memory::InMemory::new()),
|
|
||||||
stats: stats.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
poison_stats(&stats);
|
|
||||||
|
|
||||||
// record_write should not panic
|
|
||||||
store.record_write(2048);
|
|
||||||
|
|
||||||
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
|
|
||||||
assert_eq!(s.write_iops, 1);
|
|
||||||
assert_eq!(s.write_bytes, 2048);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use std::sync::Arc;
|
|||||||
use std::{future::Future, time::Duration};
|
use std::{future::Future, time::Duration};
|
||||||
|
|
||||||
use arrow::compute::concat_batches;
|
use arrow::compute::concat_batches;
|
||||||
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
|
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
|
||||||
use arrow_schema::{DataType, SchemaRef};
|
use arrow_schema::{DataType, SchemaRef};
|
||||||
use datafusion_expr::Expr;
|
use datafusion_expr::Expr;
|
||||||
use datafusion_physical_plan::ExecutionPlan;
|
use datafusion_physical_plan::ExecutionPlan;
|
||||||
@@ -17,17 +17,15 @@ use lance_datafusion::exec::execute_plan;
|
|||||||
use lance_index::scalar::FullTextSearchQuery;
|
use lance_index::scalar::FullTextSearchQuery;
|
||||||
use lance_index::scalar::inverted::SCORE_COL;
|
use lance_index::scalar::inverted::SCORE_COL;
|
||||||
use lance_index::vector::DIST_COL;
|
use lance_index::vector::DIST_COL;
|
||||||
|
use lance_io::stream::RecordBatchStreamAdapter;
|
||||||
|
|
||||||
use crate::DistanceType;
|
use crate::DistanceType;
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::rerankers::rrf::RRFReranker;
|
use crate::rerankers::rrf::RRFReranker;
|
||||||
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
|
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
|
||||||
use crate::table::BaseTable;
|
use crate::table::BaseTable;
|
||||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
|
use crate::utils::TimeoutStream;
|
||||||
use crate::{
|
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
|
||||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
|
||||||
table::AnyQuery,
|
|
||||||
};
|
|
||||||
|
|
||||||
mod hybrid;
|
mod hybrid;
|
||||||
|
|
||||||
@@ -606,14 +604,6 @@ impl Default for QueryExecutionOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QueryExecutionOptions {
|
|
||||||
fn without_output_batch_length_limit(&self) -> Self {
|
|
||||||
let mut options = self.clone();
|
|
||||||
options.max_batch_length = 0;
|
|
||||||
options
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A trait for a query object that can be executed to get results
|
/// A trait for a query object that can be executed to get results
|
||||||
///
|
///
|
||||||
/// There are various kinds of queries but they all return results
|
/// There are various kinds of queries but they all return results
|
||||||
@@ -1190,8 +1180,6 @@ impl VectorQuery {
|
|||||||
&self,
|
&self,
|
||||||
options: QueryExecutionOptions,
|
options: QueryExecutionOptions,
|
||||||
) -> Result<SendableRecordBatchStream> {
|
) -> Result<SendableRecordBatchStream> {
|
||||||
let max_batch_length = options.max_batch_length as usize;
|
|
||||||
let internal_options = options.without_output_batch_length_limit();
|
|
||||||
// clone query and specify we want to include row IDs, which can be needed for reranking
|
// clone query and specify we want to include row IDs, which can be needed for reranking
|
||||||
let mut fts_query = Query::new(self.parent.clone());
|
let mut fts_query = Query::new(self.parent.clone());
|
||||||
fts_query.request = self.request.base.clone();
|
fts_query.request = self.request.base.clone();
|
||||||
@@ -1201,8 +1189,8 @@ impl VectorQuery {
|
|||||||
|
|
||||||
vector_query.request.base.full_text_search = None;
|
vector_query.request.base.full_text_search = None;
|
||||||
let (fts_results, vec_results) = try_join!(
|
let (fts_results, vec_results) = try_join!(
|
||||||
fts_query.execute_with_options(internal_options.clone()),
|
fts_query.execute_with_options(options.clone()),
|
||||||
vector_query.inner_execute_with_options(internal_options)
|
vector_query.inner_execute_with_options(options)
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let (fts_results, vec_results) = try_join!(
|
let (fts_results, vec_results) = try_join!(
|
||||||
@@ -1257,7 +1245,9 @@ impl VectorQuery {
|
|||||||
results = results.drop_column(ROW_ID)?;
|
results = results.drop_column(ROW_ID)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(single_batch_stream(results, max_batch_length))
|
Ok(SendableRecordBatchStream::from(
|
||||||
|
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn inner_execute_with_options(
|
async fn inner_execute_with_options(
|
||||||
@@ -1266,7 +1256,6 @@ impl VectorQuery {
|
|||||||
) -> Result<SendableRecordBatchStream> {
|
) -> Result<SendableRecordBatchStream> {
|
||||||
let plan = self.create_plan(options.clone()).await?;
|
let plan = self.create_plan(options.clone()).await?;
|
||||||
let inner = execute_plan(plan, Default::default())?;
|
let inner = execute_plan(plan, Default::default())?;
|
||||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
|
||||||
let inner = if let Some(timeout) = options.timeout {
|
let inner = if let Some(timeout) = options.timeout {
|
||||||
TimeoutStream::new_boxed(inner, timeout)
|
TimeoutStream::new_boxed(inner, timeout)
|
||||||
} else {
|
} else {
|
||||||
@@ -1276,25 +1265,6 @@ impl VectorQuery {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
|
|
||||||
let schema = batch.schema();
|
|
||||||
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
|
|
||||||
return Box::pin(SimpleRecordBatchStream::new(
|
|
||||||
stream::iter([Ok(batch)]),
|
|
||||||
schema,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
|
|
||||||
let mut offset = 0;
|
|
||||||
while offset < batch.num_rows() {
|
|
||||||
let length = (batch.num_rows() - offset).min(max_batch_length);
|
|
||||||
batches.push(Ok(batch.slice(offset, length)));
|
|
||||||
offset += length;
|
|
||||||
}
|
|
||||||
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ExecutableQuery for VectorQuery {
|
impl ExecutableQuery for VectorQuery {
|
||||||
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
|
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
|
||||||
let query = AnyQuery::VectorQuery(self.request.clone());
|
let query = AnyQuery::VectorQuery(self.request.clone());
|
||||||
@@ -1783,50 +1753,6 @@ mod tests {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
|
|
||||||
let dataset_path = tmp_dir.path().join("large_test.lance");
|
|
||||||
let uri = dataset_path.to_str().unwrap();
|
|
||||||
|
|
||||||
let schema = Arc::new(ArrowSchema::new(vec![
|
|
||||||
ArrowField::new("id", DataType::Utf8, false),
|
|
||||||
ArrowField::new(
|
|
||||||
"vector",
|
|
||||||
DataType::FixedSizeList(
|
|
||||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
|
||||||
4,
|
|
||||||
),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
]));
|
|
||||||
|
|
||||||
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
|
|
||||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
|
||||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
|
|
||||||
4,
|
|
||||||
);
|
|
||||||
let batch =
|
|
||||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
|
|
||||||
|
|
||||||
let conn = connect(uri).execute().await.unwrap();
|
|
||||||
conn.create_table("my_table", vec![batch])
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn assert_stream_batches_at_most(
|
|
||||||
mut results: SendableRecordBatchStream,
|
|
||||||
max_batch_length: usize,
|
|
||||||
) {
|
|
||||||
let mut saw_batch = false;
|
|
||||||
while let Some(batch) = results.next().await {
|
|
||||||
let batch = batch.unwrap();
|
|
||||||
saw_batch = true;
|
|
||||||
assert!(batch.num_rows() <= max_batch_length);
|
|
||||||
}
|
|
||||||
assert!(saw_batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_execute_with_options() {
|
async fn test_execute_with_options() {
|
||||||
let tmp_dir = tempdir().unwrap();
|
let tmp_dir = tempdir().unwrap();
|
||||||
@@ -1846,83 +1772,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
|
|
||||||
let tmp_dir = tempdir().unwrap();
|
|
||||||
let table = make_large_vector_table(&tmp_dir, 10_000).await;
|
|
||||||
|
|
||||||
let results = table
|
|
||||||
.query()
|
|
||||||
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
|
|
||||||
.unwrap()
|
|
||||||
.limit(10_000)
|
|
||||||
.execute_with_options(QueryExecutionOptions {
|
|
||||||
max_batch_length: 100,
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_stream_batches_at_most(results, 100).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
|
|
||||||
let tmp_dir = tempdir().unwrap();
|
|
||||||
let dataset_path = tmp_dir.path();
|
|
||||||
let conn = connect(dataset_path.to_str().unwrap())
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let dims = 2;
|
|
||||||
let rows = 512;
|
|
||||||
let schema = Arc::new(ArrowSchema::new(vec![
|
|
||||||
ArrowField::new("text", DataType::Utf8, false),
|
|
||||||
ArrowField::new(
|
|
||||||
"vector",
|
|
||||||
DataType::FixedSizeList(
|
|
||||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
|
||||||
dims,
|
|
||||||
),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
]));
|
|
||||||
|
|
||||||
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
|
|
||||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
|
||||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
|
|
||||||
dims,
|
|
||||||
);
|
|
||||||
let record_batch =
|
|
||||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
|
|
||||||
let table = conn
|
|
||||||
.create_table("my_table", record_batch)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
table
|
|
||||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
|
||||||
.replace(true)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let results = table
|
|
||||||
.query()
|
|
||||||
.full_text_search(FullTextSearchQuery::new("match".to_string()))
|
|
||||||
.limit(rows)
|
|
||||||
.nearest_to(&[0.0, 0.0])
|
|
||||||
.unwrap()
|
|
||||||
.execute_with_options(QueryExecutionOptions {
|
|
||||||
max_batch_length: 100,
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_stream_batches_at_most(results, 100).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_analyze_plan() {
|
async fn test_analyze_plan() {
|
||||||
let tmp_dir = tempdir().unwrap();
|
let tmp_dir = tempdir().unwrap();
|
||||||
|
|||||||
@@ -443,13 +443,23 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
|||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if let Some(v) = options.0.get("azure_storage_account_name") {
|
// Map azure storage options to x-azure-* headers.
|
||||||
headers.insert(
|
// The option key uses underscores (e.g. "azure_client_id") while the
|
||||||
HeaderName::from_static("x-azure-storage-account-name"),
|
// header uses hyphens (e.g. "x-azure-client-id").
|
||||||
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
|
let azure_opts: [(&str, &str); 3] = [
|
||||||
message: format!("non-ascii storage account name '{}' provided", db_name),
|
("azure_storage_account_name", "x-azure-storage-account-name"),
|
||||||
})?,
|
("azure_client_id", "x-azure-client-id"),
|
||||||
);
|
("azure_tenant_id", "x-azure-tenant-id"),
|
||||||
|
];
|
||||||
|
for (opt_key, header_name) in azure_opts {
|
||||||
|
if let Some(v) = options.0.get(opt_key) {
|
||||||
|
headers.insert(
|
||||||
|
HeaderName::from_static(header_name),
|
||||||
|
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
|
||||||
|
message: format!("non-ascii value '{}' for option '{}'", v, opt_key),
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (key, value) in &config.extra_headers {
|
for (key, value) in &config.extra_headers {
|
||||||
@@ -1072,4 +1082,34 @@ mod tests {
|
|||||||
_ => panic!("Expected Runtime error"),
|
_ => panic!("Expected Runtime error"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_headers_azure_opts() {
|
||||||
|
let mut opts = HashMap::new();
|
||||||
|
opts.insert(
|
||||||
|
"azure_storage_account_name".to_string(),
|
||||||
|
"myaccount".to_string(),
|
||||||
|
);
|
||||||
|
opts.insert("azure_client_id".to_string(), "my-client-id".to_string());
|
||||||
|
opts.insert("azure_tenant_id".to_string(), "my-tenant-id".to_string());
|
||||||
|
let remote_opts = RemoteOptions::new(opts);
|
||||||
|
|
||||||
|
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||||
|
"test-key",
|
||||||
|
"us-east-1",
|
||||||
|
"testdb",
|
||||||
|
false,
|
||||||
|
&remote_opts,
|
||||||
|
None,
|
||||||
|
&ClientConfig::default(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
headers.get("x-azure-storage-account-name").unwrap(),
|
||||||
|
"myaccount"
|
||||||
|
);
|
||||||
|
assert_eq!(headers.get("x-azure-client-id").unwrap(), "my-client-id");
|
||||||
|
assert_eq!(headers.get("x-azure-tenant-id").unwrap(), "my-tenant-id");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,10 +72,6 @@ impl ServerVersion {
|
|||||||
pub fn support_structural_fts(&self) -> bool {
|
pub fn support_structural_fts(&self) -> bool {
|
||||||
self.0 >= semver::Version::new(0, 3, 0)
|
self.0 >= semver::Version::new(0, 3, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn support_multipart_write(&self) -> bool {
|
|
||||||
self.0 >= semver::Version::new(0, 4, 0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
|
pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
|
||||||
@@ -782,7 +778,12 @@ impl RemoteOptions {
|
|||||||
|
|
||||||
impl From<StorageOptions> for RemoteOptions {
|
impl From<StorageOptions> for RemoteOptions {
|
||||||
fn from(options: StorageOptions) -> Self {
|
fn from(options: StorageOptions) -> Self {
|
||||||
let supported_opts = vec!["account_name", "azure_storage_account_name"];
|
let supported_opts = vec![
|
||||||
|
"account_name",
|
||||||
|
"azure_storage_account_name",
|
||||||
|
"azure_client_id",
|
||||||
|
"azure_tenant_id",
|
||||||
|
];
|
||||||
let mut filtered = HashMap::new();
|
let mut filtered = HashMap::new();
|
||||||
for opt in supported_opts {
|
for opt in supported_opts {
|
||||||
if let Some(v) = options.0.get(opt) {
|
if let Some(v) = options.0.get(opt) {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -11,14 +11,10 @@ use arrow_ipc::CompressionType;
|
|||||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||||
use datafusion_physical_expr::EquivalenceProperties;
|
use datafusion_physical_expr::EquivalenceProperties;
|
||||||
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
|
|
||||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||||
use datafusion_physical_plan::{
|
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
|
||||||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
|
||||||
};
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use http::header::CONTENT_TYPE;
|
use http::header::CONTENT_TYPE;
|
||||||
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
|
|
||||||
|
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
|
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
|
||||||
@@ -26,16 +22,13 @@ use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
|
|||||||
use crate::remote::table::RemoteTable;
|
use crate::remote::table::RemoteTable;
|
||||||
use crate::table::AddResult;
|
use crate::table::AddResult;
|
||||||
use crate::table::datafusion::insert::COUNT_SCHEMA;
|
use crate::table::datafusion::insert::COUNT_SCHEMA;
|
||||||
use crate::table::write_progress::WriteProgressTracker;
|
|
||||||
|
|
||||||
/// ExecutionPlan for inserting data into a remote LanceDB table.
|
/// ExecutionPlan for inserting data into a remote LanceDB table.
|
||||||
///
|
///
|
||||||
/// Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint.
|
/// This plan:
|
||||||
///
|
/// 1. Requires single partition (no parallel remote inserts yet)
|
||||||
/// When `upload_id` is set, inserts are staged as part of a multipart write
|
/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint
|
||||||
/// session and the plan supports multiple partitions for parallel uploads.
|
/// 3. Stores AddResult for retrieval after execution
|
||||||
/// Without `upload_id`, the plan requires a single partition and commits
|
|
||||||
/// immediately.
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
||||||
table_name: String,
|
table_name: String,
|
||||||
@@ -45,69 +38,21 @@ pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
|||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
properties: PlanProperties,
|
properties: PlanProperties,
|
||||||
add_result: Arc<Mutex<Option<AddResult>>>,
|
add_result: Arc<Mutex<Option<AddResult>>>,
|
||||||
metrics: ExecutionPlanMetricsSet,
|
|
||||||
upload_id: Option<String>,
|
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||||
/// Create a new single-partition RemoteInsertExec.
|
/// Create a new RemoteInsertExec.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
table_name: String,
|
table_name: String,
|
||||||
identifier: String,
|
identifier: String,
|
||||||
client: RestfulLanceDbClient<S>,
|
client: RestfulLanceDbClient<S>,
|
||||||
input: Arc<dyn ExecutionPlan>,
|
input: Arc<dyn ExecutionPlan>,
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self::new_inner(
|
|
||||||
table_name, identifier, client, input, overwrite, None, tracker,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a multi-partition RemoteInsertExec for use with multipart writes.
|
|
||||||
///
|
|
||||||
/// Each partition's insert is staged under the given `upload_id` without
|
|
||||||
/// committing. The caller is responsible for calling the complete (or abort)
|
|
||||||
/// endpoint after all partitions finish.
|
|
||||||
pub fn new_multipart(
|
|
||||||
table_name: String,
|
|
||||||
identifier: String,
|
|
||||||
client: RestfulLanceDbClient<S>,
|
|
||||||
input: Arc<dyn ExecutionPlan>,
|
|
||||||
overwrite: bool,
|
|
||||||
upload_id: String,
|
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
) -> Self {
|
|
||||||
Self::new_inner(
|
|
||||||
table_name,
|
|
||||||
identifier,
|
|
||||||
client,
|
|
||||||
input,
|
|
||||||
overwrite,
|
|
||||||
Some(upload_id),
|
|
||||||
tracker,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_inner(
|
|
||||||
table_name: String,
|
|
||||||
identifier: String,
|
|
||||||
client: RestfulLanceDbClient<S>,
|
|
||||||
input: Arc<dyn ExecutionPlan>,
|
|
||||||
overwrite: bool,
|
|
||||||
upload_id: Option<String>,
|
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
) -> Self {
|
|
||||||
let num_partitions = if upload_id.is_some() {
|
|
||||||
input.output_partitioning().partition_count()
|
|
||||||
} else {
|
|
||||||
1
|
|
||||||
};
|
|
||||||
let schema = COUNT_SCHEMA.clone();
|
let schema = COUNT_SCHEMA.clone();
|
||||||
let properties = PlanProperties::new(
|
let properties = PlanProperties::new(
|
||||||
EquivalenceProperties::new(schema),
|
EquivalenceProperties::new(schema),
|
||||||
datafusion_physical_plan::Partitioning::UnknownPartitioning(num_partitions),
|
datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
|
||||||
datafusion_physical_plan::execution_plan::EmissionType::Final,
|
datafusion_physical_plan::execution_plan::EmissionType::Final,
|
||||||
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
|
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
|
||||||
);
|
);
|
||||||
@@ -120,9 +65,6 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
|||||||
overwrite,
|
overwrite,
|
||||||
properties,
|
properties,
|
||||||
add_result: Arc::new(Mutex::new(None)),
|
add_result: Arc::new(Mutex::new(None)),
|
||||||
metrics: ExecutionPlanMetricsSet::new(),
|
|
||||||
upload_id,
|
|
||||||
tracker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,10 +72,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
|||||||
// TODO: this will be used when we wire this up to Table::add().
|
// TODO: this will be used when we wire this up to Table::add().
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn add_result(&self) -> Option<AddResult> {
|
pub fn add_result(&self) -> Option<AddResult> {
|
||||||
self.add_result
|
self.add_result.lock().unwrap().clone()
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner())
|
|
||||||
.clone()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any
|
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any
|
||||||
@@ -144,7 +83,6 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
|||||||
fn stream_as_http_body(
|
fn stream_as_http_body(
|
||||||
data: SendableRecordBatchStream,
|
data: SendableRecordBatchStream,
|
||||||
error_tx: tokio::sync::oneshot::Sender<DataFusionError>,
|
error_tx: tokio::sync::oneshot::Sender<DataFusionError>,
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
) -> DataFusionResult<reqwest::Body> {
|
) -> DataFusionResult<reqwest::Body> {
|
||||||
let options = arrow_ipc::writer::IpcWriteOptions::default()
|
let options = arrow_ipc::writer::IpcWriteOptions::default()
|
||||||
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
|
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
|
||||||
@@ -156,46 +94,37 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
|||||||
|
|
||||||
let stream = futures::stream::try_unfold(
|
let stream = futures::stream::try_unfold(
|
||||||
(data, writer, Some(error_tx), false),
|
(data, writer, Some(error_tx), false),
|
||||||
move |(mut data, mut writer, error_tx, finished)| {
|
move |(mut data, mut writer, error_tx, finished)| async move {
|
||||||
let tracker = tracker.clone();
|
if finished {
|
||||||
async move {
|
return Ok(None);
|
||||||
if finished {
|
}
|
||||||
return Ok(None);
|
match data.next().await {
|
||||||
|
Some(Ok(batch)) => {
|
||||||
|
writer
|
||||||
|
.write(&batch)
|
||||||
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
|
let buffer = std::mem::take(writer.get_mut());
|
||||||
|
Ok(Some((buffer, (data, writer, error_tx, false))))
|
||||||
}
|
}
|
||||||
match data.next().await {
|
Some(Err(e)) => {
|
||||||
Some(Ok(batch)) => {
|
// Send the original error through the channel before
|
||||||
writer
|
// returning a generic error to reqwest.
|
||||||
.write(&batch)
|
if let Some(tx) = error_tx {
|
||||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
let _ = tx.send(e);
|
||||||
let buffer = std::mem::take(writer.get_mut());
|
|
||||||
if let Some(ref t) = tracker {
|
|
||||||
t.record_bytes(buffer.len());
|
|
||||||
}
|
|
||||||
Ok(Some((buffer, (data, writer, error_tx, false))))
|
|
||||||
}
|
}
|
||||||
Some(Err(e)) => {
|
Err(std::io::Error::other(
|
||||||
// Send the original error through the channel before
|
"input stream error (see error channel)",
|
||||||
// returning a generic error to reqwest.
|
))
|
||||||
if let Some(tx) = error_tx {
|
}
|
||||||
let _ = tx.send(e);
|
None => {
|
||||||
}
|
writer
|
||||||
Err(std::io::Error::other(
|
.finish()
|
||||||
"input stream error (see error channel)",
|
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
||||||
))
|
let buffer = std::mem::take(writer.get_mut());
|
||||||
}
|
if buffer.is_empty() {
|
||||||
None => {
|
Ok(None)
|
||||||
writer
|
} else {
|
||||||
.finish()
|
Ok(Some((buffer, (data, writer, None, true))))
|
||||||
.map_err(|e| std::io::Error::other(e.to_string()))?;
|
|
||||||
let buffer = std::mem::take(writer.get_mut());
|
|
||||||
if buffer.is_empty() {
|
|
||||||
Ok(None)
|
|
||||||
} else {
|
|
||||||
if let Some(ref t) = tracker {
|
|
||||||
t.record_bytes(buffer.len());
|
|
||||||
}
|
|
||||||
Ok(Some((buffer, (data, writer, None, true))))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -245,11 +174,8 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
|
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
|
||||||
if self.upload_id.is_some() {
|
// Until we have a separate commit endpoint, we need to do all inserts in a single partition
|
||||||
vec![datafusion_physical_plan::Distribution::UnspecifiedDistribution]
|
vec![datafusion_physical_plan::Distribution::SinglePartition]
|
||||||
} else {
|
|
||||||
vec![datafusion_physical_plan::Distribution::SinglePartition]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
|
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
|
||||||
@@ -265,14 +191,12 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
"RemoteInsertExec requires exactly one child".to_string(),
|
"RemoteInsertExec requires exactly one child".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
Ok(Arc::new(Self::new_inner(
|
Ok(Arc::new(Self::new(
|
||||||
self.table_name.clone(),
|
self.table_name.clone(),
|
||||||
self.identifier.clone(),
|
self.identifier.clone(),
|
||||||
self.client.clone(),
|
self.client.clone(),
|
||||||
children[0].clone(),
|
children[0].clone(),
|
||||||
self.overwrite,
|
self.overwrite,
|
||||||
self.upload_id.clone(),
|
|
||||||
self.tracker.clone(),
|
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,29 +205,18 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
partition: usize,
|
partition: usize,
|
||||||
context: Arc<TaskContext>,
|
context: Arc<TaskContext>,
|
||||||
) -> DataFusionResult<SendableRecordBatchStream> {
|
) -> DataFusionResult<SendableRecordBatchStream> {
|
||||||
if self.upload_id.is_none() && partition != 0 {
|
if partition != 0 {
|
||||||
return Err(DataFusionError::Internal(
|
return Err(DataFusionError::Internal(
|
||||||
"RemoteInsertExec only supports single partition execution without upload_id"
|
"RemoteInsertExec only supports single partition execution".to_string(),
|
||||||
.to_string(),
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let input_stream = self.input.execute(partition, context)?;
|
let input_stream = self.input.execute(0, context)?;
|
||||||
let input_schema = input_stream.schema();
|
|
||||||
let input_stream: SendableRecordBatchStream =
|
|
||||||
Box::pin(InstrumentedRecordBatchStreamAdapter::new(
|
|
||||||
input_schema,
|
|
||||||
input_stream,
|
|
||||||
partition,
|
|
||||||
&self.metrics,
|
|
||||||
));
|
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let identifier = self.identifier.clone();
|
let identifier = self.identifier.clone();
|
||||||
let overwrite = self.overwrite;
|
let overwrite = self.overwrite;
|
||||||
let add_result = self.add_result.clone();
|
let add_result = self.add_result.clone();
|
||||||
let table_name = self.table_name.clone();
|
let table_name = self.table_name.clone();
|
||||||
let upload_id = self.upload_id.clone();
|
|
||||||
let tracker = self.tracker.clone();
|
|
||||||
|
|
||||||
let stream = futures::stream::once(async move {
|
let stream = futures::stream::once(async move {
|
||||||
let mut request = client
|
let mut request = client
|
||||||
@@ -313,12 +226,9 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
if overwrite {
|
if overwrite {
|
||||||
request = request.query(&[("mode", "overwrite")]);
|
request = request.query(&[("mode", "overwrite")]);
|
||||||
}
|
}
|
||||||
if let Some(ref uid) = upload_id {
|
|
||||||
request = request.query(&[("upload_id", uid.as_str())]);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (error_tx, mut error_rx) = tokio::sync::oneshot::channel();
|
let (error_tx, mut error_rx) = tokio::sync::oneshot::channel();
|
||||||
let body = Self::stream_as_http_body(input_stream, error_tx, tracker)?;
|
let body = Self::stream_as_http_body(input_stream, error_tx)?;
|
||||||
let request = request.body(body);
|
let request = request.body(body);
|
||||||
|
|
||||||
let result: DataFusionResult<(String, _)> = async {
|
let result: DataFusionResult<(String, _)> = async {
|
||||||
@@ -352,43 +262,32 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
|
|
||||||
let (request_id, response) = result?;
|
let (request_id, response) = result?;
|
||||||
|
|
||||||
// For multipart writes, the staging response is not the final
|
let body_text = response.text().await.map_err(|e| {
|
||||||
// version. Only parse AddResult for non-multipart inserts.
|
DataFusionError::External(Box::new(Error::Http {
|
||||||
if upload_id.is_none() {
|
source: Box::new(e),
|
||||||
let body_text = response.text().await.map_err(|e| {
|
request_id: request_id.clone(),
|
||||||
|
status_code: None,
|
||||||
|
}))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let parsed_result = if body_text.trim().is_empty() {
|
||||||
|
// Backward compatible with old servers
|
||||||
|
AddResult { version: 0 }
|
||||||
|
} else {
|
||||||
|
serde_json::from_str(&body_text).map_err(|e| {
|
||||||
DataFusionError::External(Box::new(Error::Http {
|
DataFusionError::External(Box::new(Error::Http {
|
||||||
source: Box::new(e),
|
source: format!("Failed to parse add response: {}", e).into(),
|
||||||
request_id: request_id.clone(),
|
request_id: request_id.clone(),
|
||||||
status_code: None,
|
status_code: None,
|
||||||
}))
|
}))
|
||||||
})?;
|
})?
|
||||||
|
};
|
||||||
let parsed_result = if body_text.trim().is_empty() {
|
|
||||||
// Backward compatible with old servers
|
|
||||||
AddResult { version: 0 }
|
|
||||||
} else {
|
|
||||||
serde_json::from_str(&body_text).map_err(|e| {
|
|
||||||
DataFusionError::External(Box::new(Error::Http {
|
|
||||||
source: format!("Failed to parse add response: {}", e).into(),
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
status_code: None,
|
|
||||||
}))
|
|
||||||
})?
|
|
||||||
};
|
|
||||||
|
|
||||||
|
{
|
||||||
let mut res_lock = add_result.lock().map_err(|_| {
|
let mut res_lock = add_result.lock().map_err(|_| {
|
||||||
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
|
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
|
||||||
})?;
|
})?;
|
||||||
*res_lock = Some(parsed_result);
|
*res_lock = Some(parsed_result);
|
||||||
} else {
|
|
||||||
// We don't use the body in this case, but we should still consume it.
|
|
||||||
let _ = response.bytes().await.map_err(|e| {
|
|
||||||
DataFusionError::External(Box::new(Error::Http {
|
|
||||||
source: Box::new(e),
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
status_code: None,
|
|
||||||
}))
|
|
||||||
})?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a single batch with count 0 (actual count is tracked in add_result)
|
// Return a single batch with count 0 (actual count is tracked in add_result)
|
||||||
@@ -402,10 +301,6 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
|||||||
stream,
|
stream,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn metrics(&self) -> Option<MetricsSet> {
|
|
||||||
Some(self.metrics.clone_inner())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -74,10 +74,7 @@ pub mod optimize;
|
|||||||
pub mod query;
|
pub mod query;
|
||||||
pub mod schema_evolution;
|
pub mod schema_evolution;
|
||||||
pub mod update;
|
pub mod update;
|
||||||
pub mod write_progress;
|
|
||||||
use crate::index::waiter::wait_for_index;
|
use crate::index::waiter::wait_for_index;
|
||||||
#[cfg(feature = "remote")]
|
|
||||||
pub(crate) use add_data::PreprocessingOutput;
|
|
||||||
pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
|
pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
|
||||||
pub use chrono::Duration;
|
pub use chrono::Duration;
|
||||||
pub use delete::DeleteResult;
|
pub use delete::DeleteResult;
|
||||||
@@ -443,34 +440,6 @@ mod test_utils {
|
|||||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_handler_version_and_config<T>(
|
|
||||||
name: impl Into<String>,
|
|
||||||
version: semver::Version,
|
|
||||||
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
|
|
||||||
config: crate::remote::ClientConfig,
|
|
||||||
) -> Self
|
|
||||||
where
|
|
||||||
T: Into<reqwest::Body>,
|
|
||||||
{
|
|
||||||
let inner = Arc::new(
|
|
||||||
crate::remote::table::RemoteTable::new_mock_with_version_and_config(
|
|
||||||
name.into(),
|
|
||||||
handler.clone(),
|
|
||||||
Some(version),
|
|
||||||
config.clone(),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config(
|
|
||||||
handler, config,
|
|
||||||
));
|
|
||||||
Self {
|
|
||||||
inner,
|
|
||||||
database: Some(database),
|
|
||||||
// Registry is unused.
|
|
||||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2229,26 +2198,21 @@ impl BaseTable for NativeTable {
|
|||||||
|
|
||||||
let table_schema = Schema::from(&ds.schema().clone());
|
let table_schema = Schema::from(&ds.schema().clone());
|
||||||
|
|
||||||
let num_partitions = if let Some(parallelism) = add.write_parallelism {
|
// Peek at the first batch to estimate a good partition count for
|
||||||
parallelism
|
// write parallelism.
|
||||||
|
let mut peeked = PeekedScannable::new(add.data);
|
||||||
|
let num_partitions = if let Some(first_batch) = peeked.peek().await {
|
||||||
|
let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus();
|
||||||
|
estimate_write_partitions(
|
||||||
|
first_batch.get_array_memory_size(),
|
||||||
|
first_batch.num_rows(),
|
||||||
|
peeked.num_rows(),
|
||||||
|
max_partitions,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
// Peek at the first batch to estimate a good partition count for
|
1
|
||||||
// write parallelism.
|
|
||||||
let mut peeked = PeekedScannable::new(add.data);
|
|
||||||
let n = if let Some(first_batch) = peeked.peek().await {
|
|
||||||
let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus();
|
|
||||||
estimate_write_partitions(
|
|
||||||
first_batch.get_array_memory_size(),
|
|
||||||
first_batch.num_rows(),
|
|
||||||
peeked.num_rows(),
|
|
||||||
max_partitions,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
1
|
|
||||||
};
|
|
||||||
add.data = Box::new(peeked);
|
|
||||||
n
|
|
||||||
};
|
};
|
||||||
|
add.data = Box::new(peeked);
|
||||||
|
|
||||||
let output = add.into_plan(&table_schema, &table_def)?;
|
let output = add.into_plan(&table_schema, &table_def)?;
|
||||||
|
|
||||||
@@ -2277,21 +2241,13 @@ impl BaseTable for NativeTable {
|
|||||||
|
|
||||||
let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
|
let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
|
||||||
|
|
||||||
let tracker_for_tasks = output.tracker.clone();
|
|
||||||
if let Some(ref t) = tracker_for_tasks {
|
|
||||||
t.set_total_tasks(num_partitions);
|
|
||||||
}
|
|
||||||
let _finish = write_progress::FinishOnDrop(output.tracker);
|
|
||||||
|
|
||||||
// Execute all partitions in parallel.
|
// Execute all partitions in parallel.
|
||||||
let task_ctx = Arc::new(TaskContext::default());
|
let task_ctx = Arc::new(TaskContext::default());
|
||||||
let handles = FuturesUnordered::new();
|
let handles = FuturesUnordered::new();
|
||||||
for partition in 0..num_partitions {
|
for partition in 0..num_partitions {
|
||||||
let exec = insert_exec.clone();
|
let exec = insert_exec.clone();
|
||||||
let ctx = task_ctx.clone();
|
let ctx = task_ctx.clone();
|
||||||
let tracker = tracker_for_tasks.clone();
|
|
||||||
handles.push(tokio::spawn(async move {
|
handles.push(tokio::spawn(async move {
|
||||||
let _guard = tracker.as_ref().map(|t| t.track_task());
|
|
||||||
let mut stream = exec
|
let mut stream = exec
|
||||||
.execute(partition, ctx)
|
.execute(partition, ctx)
|
||||||
.map_err(|e| -> Error { e.into() })?;
|
.map_err(|e| -> Error { e.into() })?;
|
||||||
|
|||||||
@@ -13,9 +13,6 @@ use crate::embeddings::EmbeddingRegistry;
|
|||||||
use crate::table::datafusion::cast::cast_to_table_schema;
|
use crate::table::datafusion::cast::cast_to_table_schema;
|
||||||
use crate::table::datafusion::reject_nan::reject_nan_vectors;
|
use crate::table::datafusion::reject_nan::reject_nan_vectors;
|
||||||
use crate::table::datafusion::scannable_exec::ScannableExec;
|
use crate::table::datafusion::scannable_exec::ScannableExec;
|
||||||
use crate::table::write_progress::ProgressCallback;
|
|
||||||
use crate::table::write_progress::WriteProgress;
|
|
||||||
use crate::table::write_progress::WriteProgressTracker;
|
|
||||||
use crate::{Error, Result};
|
use crate::{Error, Result};
|
||||||
|
|
||||||
use super::{BaseTable, TableDefinition, WriteOptions};
|
use super::{BaseTable, TableDefinition, WriteOptions};
|
||||||
@@ -55,8 +52,6 @@ pub struct AddDataBuilder {
|
|||||||
pub(crate) write_options: WriteOptions,
|
pub(crate) write_options: WriteOptions,
|
||||||
pub(crate) on_nan_vectors: NaNVectorBehavior,
|
pub(crate) on_nan_vectors: NaNVectorBehavior,
|
||||||
pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||||
pub(crate) progress_callback: Option<ProgressCallback>,
|
|
||||||
pub(crate) write_parallelism: Option<usize>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for AddDataBuilder {
|
impl std::fmt::Debug for AddDataBuilder {
|
||||||
@@ -82,8 +77,6 @@ impl AddDataBuilder {
|
|||||||
write_options: WriteOptions::default(),
|
write_options: WriteOptions::default(),
|
||||||
on_nan_vectors: NaNVectorBehavior::default(),
|
on_nan_vectors: NaNVectorBehavior::default(),
|
||||||
embedding_registry,
|
embedding_registry,
|
||||||
progress_callback: None,
|
|
||||||
write_parallelism: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,43 +101,7 @@ impl AddDataBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set a callback to receive progress updates during the add operation.
|
|
||||||
///
|
|
||||||
/// The callback is invoked once per batch written, and once more with
|
|
||||||
/// [`WriteProgress::done`] set to `true` when the write completes.
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// # use lancedb::Table;
|
|
||||||
/// # async fn example(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
/// let batch = arrow_array::record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
|
||||||
/// table.add(batch)
|
|
||||||
/// .progress(|p| println!("{}/{:?} rows", p.output_rows(), p.total_rows()))
|
|
||||||
/// .execute()
|
|
||||||
/// .await?;
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
pub fn progress(mut self, callback: impl FnMut(&WriteProgress) + Send + 'static) -> Self {
|
|
||||||
self.progress_callback = Some(Arc::new(std::sync::Mutex::new(callback)));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the number of parallel write streams.
|
|
||||||
///
|
|
||||||
/// By default, the number of streams is estimated from the data size.
|
|
||||||
/// Setting this to `1` disables parallel writes.
|
|
||||||
pub fn write_parallelism(mut self, parallelism: usize) -> Self {
|
|
||||||
self.write_parallelism = Some(parallelism);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn execute(self) -> Result<AddResult> {
|
pub async fn execute(self) -> Result<AddResult> {
|
||||||
if self.write_parallelism.map(|p| p == 0).unwrap_or(false) {
|
|
||||||
return Err(Error::InvalidInput {
|
|
||||||
message: "write_parallelism must be greater than 0".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
self.parent.clone().add(self).await
|
self.parent.clone().add(self).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,11 +130,8 @@ impl AddDataBuilder {
|
|||||||
scannable_with_embeddings(self.data, table_def, self.embedding_registry.as_ref())?;
|
scannable_with_embeddings(self.data, table_def, self.embedding_registry.as_ref())?;
|
||||||
|
|
||||||
let rescannable = self.data.rescannable();
|
let rescannable = self.data.rescannable();
|
||||||
let tracker = self
|
|
||||||
.progress_callback
|
|
||||||
.map(|cb| Arc::new(WriteProgressTracker::new(cb, self.data.num_rows())));
|
|
||||||
let plan: Arc<dyn datafusion_physical_plan::ExecutionPlan> =
|
let plan: Arc<dyn datafusion_physical_plan::ExecutionPlan> =
|
||||||
Arc::new(ScannableExec::new(self.data, tracker.clone()));
|
Arc::new(ScannableExec::new(self.data));
|
||||||
// Skip casting when overwriting — the input schema replaces the table schema.
|
// Skip casting when overwriting — the input schema replaces the table schema.
|
||||||
let plan = if overwrite {
|
let plan = if overwrite {
|
||||||
plan
|
plan
|
||||||
@@ -195,7 +149,6 @@ impl AddDataBuilder {
|
|||||||
rescannable,
|
rescannable,
|
||||||
write_options: self.write_options,
|
write_options: self.write_options,
|
||||||
mode: self.mode,
|
mode: self.mode,
|
||||||
tracker,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,7 +161,6 @@ pub struct PreprocessingOutput {
|
|||||||
pub rescannable: bool,
|
pub rescannable: bool,
|
||||||
pub write_options: WriteOptions,
|
pub write_options: WriteOptions,
|
||||||
pub mode: AddDataMode,
|
pub mode: AddDataMode,
|
||||||
pub tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check that the input schema is valid for insert.
|
/// Check that the input schema is valid for insert.
|
||||||
|
|||||||
@@ -12,16 +12,13 @@ use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
|||||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||||
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
|
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
|
||||||
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
|
|
||||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||||
use datafusion_physical_plan::{
|
use datafusion_physical_plan::{
|
||||||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
||||||
};
|
};
|
||||||
use futures::TryStreamExt;
|
|
||||||
use lance::Dataset;
|
use lance::Dataset;
|
||||||
use lance::dataset::transaction::{Operation, Transaction};
|
use lance::dataset::transaction::{Operation, Transaction};
|
||||||
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
|
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
|
||||||
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
|
|
||||||
use lance_table::format::Fragment;
|
use lance_table::format::Fragment;
|
||||||
|
|
||||||
use crate::table::dataset::DatasetConsistencyWrapper;
|
use crate::table::dataset::DatasetConsistencyWrapper;
|
||||||
@@ -83,7 +80,6 @@ pub struct InsertExec {
|
|||||||
write_params: WriteParams,
|
write_params: WriteParams,
|
||||||
properties: PlanProperties,
|
properties: PlanProperties,
|
||||||
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
|
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
|
||||||
metrics: ExecutionPlanMetricsSet,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InsertExec {
|
impl InsertExec {
|
||||||
@@ -109,7 +105,6 @@ impl InsertExec {
|
|||||||
write_params,
|
write_params,
|
||||||
properties,
|
properties,
|
||||||
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
|
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
|
||||||
metrics: ExecutionPlanMetricsSet::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -181,19 +176,6 @@ impl ExecutionPlan for InsertExec {
|
|||||||
let total_partitions = self.input.output_partitioning().partition_count();
|
let total_partitions = self.input.output_partitioning().partition_count();
|
||||||
let ds_wrapper = self.ds_wrapper.clone();
|
let ds_wrapper = self.ds_wrapper.clone();
|
||||||
|
|
||||||
let output_bytes = MetricBuilder::new(&self.metrics).output_bytes(partition);
|
|
||||||
let input_schema = input_stream.schema();
|
|
||||||
let input_stream: SendableRecordBatchStream =
|
|
||||||
Box::pin(InstrumentedRecordBatchStreamAdapter::new(
|
|
||||||
input_schema,
|
|
||||||
input_stream.map_ok(move |batch| {
|
|
||||||
output_bytes.add(batch.get_array_memory_size());
|
|
||||||
batch
|
|
||||||
}),
|
|
||||||
partition,
|
|
||||||
&self.metrics,
|
|
||||||
));
|
|
||||||
|
|
||||||
let stream = futures::stream::once(async move {
|
let stream = futures::stream::once(async move {
|
||||||
let transaction = InsertBuilder::new(dataset.clone())
|
let transaction = InsertBuilder::new(dataset.clone())
|
||||||
.with_params(&write_params)
|
.with_params(&write_params)
|
||||||
@@ -204,9 +186,7 @@ impl ExecutionPlan for InsertExec {
|
|||||||
|
|
||||||
let to_commit = {
|
let to_commit = {
|
||||||
// Don't hold the lock over an await point.
|
// Don't hold the lock over an await point.
|
||||||
let mut txns = partial_transactions
|
let mut txns = partial_transactions.lock().unwrap();
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner());
|
|
||||||
txns.push(transaction);
|
txns.push(transaction);
|
||||||
if txns.len() == total_partitions {
|
if txns.len() == total_partitions {
|
||||||
Some(std::mem::take(&mut *txns))
|
Some(std::mem::take(&mut *txns))
|
||||||
@@ -235,10 +215,6 @@ impl ExecutionPlan for InsertExec {
|
|||||||
stream,
|
stream,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn metrics(&self) -> Option<MetricsSet> {
|
|
||||||
Some(self.metrics.clone_inner())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -7,21 +7,17 @@ use std::sync::{Arc, Mutex};
|
|||||||
use datafusion_common::{DataFusionError, Result as DFResult, Statistics, stats::Precision};
|
use datafusion_common::{DataFusionError, Result as DFResult, Statistics, stats::Precision};
|
||||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
|
||||||
use datafusion_physical_plan::{
|
use datafusion_physical_plan::{
|
||||||
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execution_plan::EmissionType,
|
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execution_plan::EmissionType,
|
||||||
};
|
};
|
||||||
use futures::TryStreamExt;
|
|
||||||
|
|
||||||
use crate::table::write_progress::WriteProgressTracker;
|
|
||||||
use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};
|
use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};
|
||||||
|
|
||||||
pub(crate) struct ScannableExec {
|
pub struct ScannableExec {
|
||||||
// We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
|
// We don't require Scannable to by Sync, so we wrap it in a Mutex to allow safe concurrent access.
|
||||||
source: Mutex<Box<dyn Scannable>>,
|
source: Mutex<Box<dyn Scannable>>,
|
||||||
num_rows: Option<usize>,
|
num_rows: Option<usize>,
|
||||||
properties: PlanProperties,
|
properties: PlanProperties,
|
||||||
tracker: Option<Arc<WriteProgressTracker>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for ScannableExec {
|
impl std::fmt::Debug for ScannableExec {
|
||||||
@@ -34,7 +30,7 @@ impl std::fmt::Debug for ScannableExec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ScannableExec {
|
impl ScannableExec {
|
||||||
pub fn new(source: Box<dyn Scannable>, tracker: Option<Arc<WriteProgressTracker>>) -> Self {
|
pub fn new(source: Box<dyn Scannable>) -> Self {
|
||||||
let schema = source.schema();
|
let schema = source.schema();
|
||||||
let eq_properties = EquivalenceProperties::new(schema);
|
let eq_properties = EquivalenceProperties::new(schema);
|
||||||
let properties = PlanProperties::new(
|
let properties = PlanProperties::new(
|
||||||
@@ -50,7 +46,6 @@ impl ScannableExec {
|
|||||||
source,
|
source,
|
||||||
num_rows,
|
num_rows,
|
||||||
properties,
|
properties,
|
||||||
tracker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -107,18 +102,7 @@ impl ExecutionPlan for ScannableExec {
|
|||||||
Err(poison) => poison.into_inner().scan_as_stream(),
|
Err(poison) => poison.into_inner().scan_as_stream(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let tracker = self.tracker.clone();
|
Ok(stream.into_df_stream())
|
||||||
let stream = stream.into_df_stream().map_ok(move |batch| {
|
|
||||||
if let Some(ref t) = tracker {
|
|
||||||
t.record_batch(batch.num_rows(), batch.get_array_memory_size());
|
|
||||||
}
|
|
||||||
batch
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(Box::pin(RecordBatchStreamAdapter::new(
|
|
||||||
self.schema(),
|
|
||||||
stream,
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn partition_statistics(&self, _partition: Option<usize>) -> DFResult<Statistics> {
|
fn partition_statistics(&self, _partition: Option<usize>) -> DFResult<Statistics> {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
/// pinned dataset regardless of consistency mode.
|
/// pinned dataset regardless of consistency mode.
|
||||||
pub async fn get(&self) -> Result<Arc<Dataset>> {
|
pub async fn get(&self) -> Result<Arc<Dataset>> {
|
||||||
{
|
{
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
if state.pinned_version.is_some() {
|
if state.pinned_version.is_some() {
|
||||||
return Ok(state.dataset.clone());
|
return Ok(state.dataset.clone());
|
||||||
}
|
}
|
||||||
@@ -101,7 +101,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
}
|
}
|
||||||
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
|
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
|
||||||
ConsistencyMode::Lazy => {
|
ConsistencyMode::Lazy => {
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
Ok(state.dataset.clone())
|
Ok(state.dataset.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
|
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
|
||||||
/// is silently ignored — the write already committed to storage.
|
/// is silently ignored — the write already committed to storage.
|
||||||
pub fn update(&self, dataset: Dataset) {
|
pub fn update(&self, dataset: Dataset) {
|
||||||
let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
|
let mut state = self.state.lock().unwrap();
|
||||||
if state.pinned_version.is_some() {
|
if state.pinned_version.is_some() {
|
||||||
// A concurrent as_time_travel() beat us here. The write succeeded
|
// A concurrent as_time_travel() beat us here. The write succeeded
|
||||||
// in storage, but since we're now pinned we don't advance the
|
// in storage, but since we're now pinned we don't advance the
|
||||||
@@ -139,7 +139,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
|
|
||||||
/// Check that the dataset is in a mutable mode (Latest).
|
/// Check that the dataset is in a mutable mode (Latest).
|
||||||
pub fn ensure_mutable(&self) -> Result<()> {
|
pub fn ensure_mutable(&self) -> Result<()> {
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
if state.pinned_version.is_some() {
|
if state.pinned_version.is_some() {
|
||||||
Err(crate::Error::InvalidInput {
|
Err(crate::Error::InvalidInput {
|
||||||
message: "table cannot be modified when a specific version is checked out"
|
message: "table cannot be modified when a specific version is checked out"
|
||||||
@@ -152,16 +152,13 @@ impl DatasetConsistencyWrapper {
|
|||||||
|
|
||||||
/// Returns the version, if in time travel mode, or None otherwise.
|
/// Returns the version, if in time travel mode, or None otherwise.
|
||||||
pub fn time_travel_version(&self) -> Option<u64> {
|
pub fn time_travel_version(&self) -> Option<u64> {
|
||||||
self.state
|
self.state.lock().unwrap().pinned_version
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner())
|
|
||||||
.pinned_version
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a wrapper in latest version mode.
|
/// Convert into a wrapper in latest version mode.
|
||||||
pub async fn as_latest(&self) -> Result<()> {
|
pub async fn as_latest(&self) -> Result<()> {
|
||||||
let dataset = {
|
let dataset = {
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
if state.pinned_version.is_none() {
|
if state.pinned_version.is_none() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
@@ -171,7 +168,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
let latest_version = dataset.latest_version_id().await?;
|
let latest_version = dataset.latest_version_id().await?;
|
||||||
let new_dataset = dataset.checkout_version(latest_version).await?;
|
let new_dataset = dataset.checkout_version(latest_version).await?;
|
||||||
|
|
||||||
let mut state = self.state.lock()?;
|
let mut state = self.state.lock().unwrap();
|
||||||
if state.pinned_version.is_some() {
|
if state.pinned_version.is_some() {
|
||||||
state.dataset = Arc::new(new_dataset);
|
state.dataset = Arc::new(new_dataset);
|
||||||
state.pinned_version = None;
|
state.pinned_version = None;
|
||||||
@@ -187,7 +184,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
let target_ref = target_version.into();
|
let target_ref = target_version.into();
|
||||||
|
|
||||||
let (should_checkout, dataset) = {
|
let (should_checkout, dataset) = {
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
let should = match state.pinned_version {
|
let should = match state.pinned_version {
|
||||||
None => true,
|
None => true,
|
||||||
Some(version) => match &target_ref {
|
Some(version) => match &target_ref {
|
||||||
@@ -207,7 +204,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
let new_dataset = dataset.checkout_version(target_ref).await?;
|
let new_dataset = dataset.checkout_version(target_ref).await?;
|
||||||
let version_value = new_dataset.version().version;
|
let version_value = new_dataset.version().version;
|
||||||
|
|
||||||
let mut state = self.state.lock()?;
|
let mut state = self.state.lock().unwrap();
|
||||||
state.dataset = Arc::new(new_dataset);
|
state.dataset = Arc::new(new_dataset);
|
||||||
state.pinned_version = Some(version_value);
|
state.pinned_version = Some(version_value);
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -215,7 +212,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
|
|
||||||
pub async fn reload(&self) -> Result<()> {
|
pub async fn reload(&self) -> Result<()> {
|
||||||
let (dataset, pinned_version) = {
|
let (dataset, pinned_version) = {
|
||||||
let state = self.state.lock()?;
|
let state = self.state.lock().unwrap();
|
||||||
(state.dataset.clone(), state.pinned_version)
|
(state.dataset.clone(), state.pinned_version)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -233,7 +230,7 @@ impl DatasetConsistencyWrapper {
|
|||||||
|
|
||||||
let new_dataset = dataset.checkout_version(version).await?;
|
let new_dataset = dataset.checkout_version(version).await?;
|
||||||
|
|
||||||
let mut state = self.state.lock()?;
|
let mut state = self.state.lock().unwrap();
|
||||||
if state.pinned_version == Some(version) {
|
if state.pinned_version == Some(version) {
|
||||||
state.dataset = Arc::new(new_dataset);
|
state.dataset = Arc::new(new_dataset);
|
||||||
}
|
}
|
||||||
@@ -245,14 +242,14 @@ impl DatasetConsistencyWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
|
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
|
||||||
let dataset = { state.lock()?.dataset.clone() };
|
let dataset = { state.lock().unwrap().dataset.clone() };
|
||||||
|
|
||||||
let mut ds = (*dataset).clone();
|
let mut ds = (*dataset).clone();
|
||||||
ds.checkout_latest().await?;
|
ds.checkout_latest().await?;
|
||||||
let new_arc = Arc::new(ds);
|
let new_arc = Arc::new(ds);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut state = state.lock()?;
|
let mut state = state.lock().unwrap();
|
||||||
if state.pinned_version.is_none()
|
if state.pinned_version.is_none()
|
||||||
&& new_arc.manifest().version >= state.dataset.manifest().version
|
&& new_arc.manifest().version >= state.dataset.manifest().version
|
||||||
{
|
{
|
||||||
@@ -615,108 +612,4 @@ mod tests {
|
|||||||
let s = io_stats.incremental_stats();
|
let s = io_stats.incremental_stats();
|
||||||
assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed());
|
assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper: poison the mutex inside a DatasetConsistencyWrapper.
|
|
||||||
fn poison_state(wrapper: &DatasetConsistencyWrapper) {
|
|
||||||
let state = wrapper.state.clone();
|
|
||||||
let handle = std::thread::spawn(move || {
|
|
||||||
let _guard = state.lock().unwrap();
|
|
||||||
panic!("intentional panic to poison mutex");
|
|
||||||
});
|
|
||||||
let _ = handle.join(); // join collects the panic
|
|
||||||
assert!(wrapper.state.lock().is_err(), "mutex should be poisoned");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_get_returns_error_on_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
// get() should return Err, not panic
|
|
||||||
let result = wrapper.get().await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_ensure_mutable_returns_error_on_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
let result = wrapper.ensure_mutable();
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_update_recovers_from_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
let ds_v2 = append_to_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
// update() returns (), should not panic
|
|
||||||
wrapper.update(ds_v2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_time_travel_version_recovers_from_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
// Should not panic, returns whatever was in the mutex
|
|
||||||
let _version = wrapper.time_travel_version();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_as_latest_returns_error_on_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
let result = wrapper.as_latest().await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_as_time_travel_returns_error_on_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
let result = wrapper.as_time_travel(1u64).await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_reload_returns_error_on_poisoned_lock() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let uri = dir.path().to_str().unwrap();
|
|
||||||
let ds = create_test_dataset(uri).await;
|
|
||||||
|
|
||||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
|
||||||
poison_state(&wrapper);
|
|
||||||
|
|
||||||
let result = wrapper.reload().await;
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
|
|||||||
use crate::query::{
|
use crate::query::{
|
||||||
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
|
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
|
||||||
};
|
};
|
||||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
|
use crate::utils::{TimeoutStream, default_vector_column};
|
||||||
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
|
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
|
||||||
use arrow::datatypes::{Float32Type, UInt8Type};
|
use arrow::datatypes::{Float32Type, UInt8Type};
|
||||||
use arrow_array::Array;
|
use arrow_array::Array;
|
||||||
@@ -66,7 +66,6 @@ async fn execute_generic_query(
|
|||||||
) -> Result<DatasetRecordBatchStream> {
|
) -> Result<DatasetRecordBatchStream> {
|
||||||
let plan = create_plan(table, query, options.clone()).await?;
|
let plan = create_plan(table, query, options.clone()).await?;
|
||||||
let inner = execute_plan(plan, Default::default())?;
|
let inner = execute_plan(plan, Default::default())?;
|
||||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
|
||||||
let inner = if let Some(timeout) = options.timeout {
|
let inner = if let Some(timeout) = options.timeout {
|
||||||
TimeoutStream::new_boxed(inner, timeout)
|
TimeoutStream::new_boxed(inner, timeout)
|
||||||
} else {
|
} else {
|
||||||
@@ -201,9 +200,7 @@ pub async fn create_plan(
|
|||||||
scanner.with_row_id();
|
scanner.with_row_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.max_batch_length > 0 {
|
scanner.batch_size(options.max_batch_length as usize);
|
||||||
scanner.batch_size(options.max_batch_length as usize);
|
|
||||||
}
|
|
||||||
|
|
||||||
if query.base.fast_search {
|
if query.base.fast_search {
|
||||||
scanner.fast_search();
|
scanner.fast_search();
|
||||||
|
|||||||
@@ -1,431 +0,0 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
||||||
|
|
||||||
//! Progress monitoring for write operations.
|
|
||||||
//!
|
|
||||||
//! You can add a callback to process progress in [`crate::table::AddDataBuilder::progress`].
|
|
||||||
//! [`WriteProgress`] is the struct passed to the callback.
|
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
|
|
||||||
/// Progress snapshot for a write operation.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct WriteProgress {
|
|
||||||
// These are private and only accessible via getters, to make it easy to add
|
|
||||||
// new fields without breaking existing callbacks.
|
|
||||||
elapsed: Duration,
|
|
||||||
output_rows: usize,
|
|
||||||
output_bytes: usize,
|
|
||||||
total_rows: Option<usize>,
|
|
||||||
active_tasks: usize,
|
|
||||||
total_tasks: usize,
|
|
||||||
done: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WriteProgress {
|
|
||||||
/// Wall-clock time since monitoring started.
|
|
||||||
pub fn elapsed(&self) -> Duration {
|
|
||||||
self.elapsed
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of rows written so far.
|
|
||||||
pub fn output_rows(&self) -> usize {
|
|
||||||
self.output_rows
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of bytes written so far.
|
|
||||||
pub fn output_bytes(&self) -> usize {
|
|
||||||
self.output_bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Total rows expected.
|
|
||||||
///
|
|
||||||
/// Populated when the input source reports a row count (e.g. a
|
|
||||||
/// [`arrow_array::RecordBatch`]). Always `Some` when [`WriteProgress::done`]
|
|
||||||
/// is `true` — falling back to the actual number of rows written.
|
|
||||||
pub fn total_rows(&self) -> Option<usize> {
|
|
||||||
self.total_rows
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of parallel write tasks currently in flight.
|
|
||||||
pub fn active_tasks(&self) -> usize {
|
|
||||||
self.active_tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Total number of parallel write tasks (i.e. the write parallelism).
|
|
||||||
pub fn total_tasks(&self) -> usize {
|
|
||||||
self.total_tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether the write operation has completed.
|
|
||||||
///
|
|
||||||
/// The final callback always has `done = true`. Callers can use this to
|
|
||||||
/// finalize progress bars or perform cleanup.
|
|
||||||
pub fn done(&self) -> bool {
|
|
||||||
self.done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Callback type for progress updates.
|
|
||||||
///
|
|
||||||
/// Callbacks are serialized by the tracker and are never invoked reentrantly,
|
|
||||||
/// so `FnMut` is safe to use here.
|
|
||||||
pub type ProgressCallback = Arc<Mutex<dyn FnMut(&WriteProgress) + Send>>;
|
|
||||||
|
|
||||||
/// Tracks progress of a write operation and invokes a [`ProgressCallback`].
|
|
||||||
///
|
|
||||||
/// Call [`WriteProgressTracker::record_batch`] for each batch written.
|
|
||||||
/// Call [`WriteProgressTracker::finish`] once after all data is written.
|
|
||||||
///
|
|
||||||
/// The callback is never invoked reentrantly: all state updates and callback
|
|
||||||
/// invocations are serialized behind a single lock.
|
|
||||||
impl std::fmt::Debug for WriteProgressTracker {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("WriteProgressTracker")
|
|
||||||
.field("total_rows", &self.total_rows)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) struct WriteProgressTracker {
|
|
||||||
rows_and_bytes: std::sync::Mutex<(usize, usize)>,
|
|
||||||
/// Wire bytes tracked separately by the insert layer. When set (> 0),
|
|
||||||
/// this takes precedence over the in-memory bytes from `rows_and_bytes`.
|
|
||||||
wire_bytes: AtomicUsize,
|
|
||||||
active_tasks: Arc<AtomicUsize>,
|
|
||||||
total_tasks: AtomicUsize,
|
|
||||||
start: Instant,
|
|
||||||
/// Known total rows from the input source, if available.
|
|
||||||
total_rows: Option<usize>,
|
|
||||||
callback: ProgressCallback,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WriteProgressTracker {
|
|
||||||
pub fn new(callback: ProgressCallback, total_rows: Option<usize>) -> Self {
|
|
||||||
Self {
|
|
||||||
rows_and_bytes: std::sync::Mutex::new((0, 0)),
|
|
||||||
wire_bytes: AtomicUsize::new(0),
|
|
||||||
active_tasks: Arc::new(AtomicUsize::new(0)),
|
|
||||||
total_tasks: AtomicUsize::new(1),
|
|
||||||
start: Instant::now(),
|
|
||||||
total_rows,
|
|
||||||
callback,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the total number of parallel write tasks (the write parallelism).
|
|
||||||
pub fn set_total_tasks(&self, n: usize) {
|
|
||||||
self.total_tasks.store(n, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Increment the active task count. Returns a guard that decrements on drop.
|
|
||||||
pub fn track_task(&self) -> ActiveTaskGuard {
|
|
||||||
self.active_tasks.fetch_add(1, Ordering::Relaxed);
|
|
||||||
ActiveTaskGuard(self.active_tasks.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Record a batch of rows passing through the scan node.
|
|
||||||
pub fn record_batch(&self, rows: usize, bytes: usize) {
|
|
||||||
// Lock order: callback first, then rows_and_bytes. This is the only
|
|
||||||
// order used anywhere, so deadlocks cannot occur.
|
|
||||||
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
|
|
||||||
let mut guard = self
|
|
||||||
.rows_and_bytes
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner());
|
|
||||||
guard.0 += rows;
|
|
||||||
guard.1 += bytes;
|
|
||||||
let progress = self.snapshot(guard.0, guard.1, false);
|
|
||||||
drop(guard);
|
|
||||||
cb(&progress);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Record wire bytes from the insert layer (e.g. IPC-encoded bytes for
|
|
||||||
/// remote writes). When wire bytes are recorded, they take precedence over
|
|
||||||
/// the in-memory Arrow bytes tracked by [`record_batch`].
|
|
||||||
pub fn record_bytes(&self, bytes: usize) {
|
|
||||||
self.wire_bytes.fetch_add(bytes, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Emit the final progress callback indicating the write is complete.
|
|
||||||
///
|
|
||||||
/// `total_rows` is always `Some` on the final callback: it uses the known
|
|
||||||
/// total if available, or falls back to the number of rows actually written.
|
|
||||||
pub fn finish(&self) {
|
|
||||||
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
|
|
||||||
let guard = self
|
|
||||||
.rows_and_bytes
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner());
|
|
||||||
let mut snap = self.snapshot(guard.0, guard.1, true);
|
|
||||||
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
|
|
||||||
drop(guard);
|
|
||||||
cb(&snap);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn snapshot(&self, rows: usize, in_memory_bytes: usize, done: bool) -> WriteProgress {
|
|
||||||
let wire = self.wire_bytes.load(Ordering::Relaxed);
|
|
||||||
// Prefer wire bytes (actual I/O size) when the insert layer is
|
|
||||||
// tracking them; fall back to in-memory Arrow size otherwise.
|
|
||||||
// TODO: for local writes, track actual bytes written by Lance
|
|
||||||
// instead of using in-memory Arrow size as a proxy.
|
|
||||||
let output_bytes = if wire > 0 { wire } else { in_memory_bytes };
|
|
||||||
WriteProgress {
|
|
||||||
elapsed: self.start.elapsed(),
|
|
||||||
output_rows: rows,
|
|
||||||
output_bytes,
|
|
||||||
total_rows: self.total_rows,
|
|
||||||
active_tasks: self.active_tasks.load(Ordering::Relaxed),
|
|
||||||
total_tasks: self.total_tasks.load(Ordering::Relaxed),
|
|
||||||
done,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// RAII guard that decrements the active task count when dropped.
|
|
||||||
pub(crate) struct ActiveTaskGuard(Arc<AtomicUsize>);
|
|
||||||
|
|
||||||
impl Drop for ActiveTaskGuard {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.0.fetch_sub(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// RAII guard that calls [`WriteProgressTracker::finish`] on drop.
|
|
||||||
///
|
|
||||||
/// This ensures the final `done=true` callback fires even if the write
|
|
||||||
/// errors or the future is cancelled.
|
|
||||||
pub(crate) struct FinishOnDrop(pub Option<Arc<WriteProgressTracker>>);
|
|
||||||
|
|
||||||
impl Drop for FinishOnDrop {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Some(t) = self.0.take() {
|
|
||||||
t.finish();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
|
|
||||||
use arrow_array::record_batch;
|
|
||||||
|
|
||||||
use crate::connect;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_progress_monitor_fires_callback() {
|
|
||||||
let db = connect("memory://").execute().await.unwrap();
|
|
||||||
|
|
||||||
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
|
||||||
let table = db
|
|
||||||
.create_table("progress_test", batch)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let callback_count = Arc::new(AtomicUsize::new(0));
|
|
||||||
let last_rows = Arc::new(AtomicUsize::new(0));
|
|
||||||
let max_active = Arc::new(AtomicUsize::new(0));
|
|
||||||
let last_total_tasks = Arc::new(AtomicUsize::new(0));
|
|
||||||
let cb_count = callback_count.clone();
|
|
||||||
let cb_rows = last_rows.clone();
|
|
||||||
let cb_active = max_active.clone();
|
|
||||||
let cb_total_tasks = last_total_tasks.clone();
|
|
||||||
|
|
||||||
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
|
|
||||||
table
|
|
||||||
.add(new_data)
|
|
||||||
.progress(move |p| {
|
|
||||||
cb_count.fetch_add(1, Ordering::SeqCst);
|
|
||||||
cb_rows.store(p.output_rows(), Ordering::SeqCst);
|
|
||||||
cb_active.fetch_max(p.active_tasks(), Ordering::SeqCst);
|
|
||||||
cb_total_tasks.store(p.total_tasks(), Ordering::SeqCst);
|
|
||||||
})
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(table.count_rows(None).await.unwrap(), 6);
|
|
||||||
assert!(callback_count.load(Ordering::SeqCst) >= 1);
|
|
||||||
// Progress tracks the newly inserted rows, not the total table size.
|
|
||||||
assert_eq!(last_rows.load(Ordering::SeqCst), 3);
|
|
||||||
// At least one callback should have seen an active task.
|
|
||||||
assert!(max_active.load(Ordering::SeqCst) >= 1);
|
|
||||||
// total_tasks should reflect the write parallelism.
|
|
||||||
assert!(last_total_tasks.load(Ordering::SeqCst) >= 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_progress_done_fires_at_end() {
|
|
||||||
let db = connect("memory://").execute().await.unwrap();
|
|
||||||
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
|
||||||
let table = db
|
|
||||||
.create_table("progress_done", batch)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let seen_done = Arc::new(std::sync::Mutex::new(Vec::<bool>::new()));
|
|
||||||
let seen = seen_done.clone();
|
|
||||||
|
|
||||||
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
|
|
||||||
table
|
|
||||||
.add(new_data)
|
|
||||||
.progress(move |p| {
|
|
||||||
seen.lock().unwrap().push(p.done());
|
|
||||||
})
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let done_flags = seen_done.lock().unwrap();
|
|
||||||
assert!(!done_flags.is_empty(), "at least one callback must fire");
|
|
||||||
// Only the last callback should have done=true.
|
|
||||||
let last = *done_flags.last().unwrap();
|
|
||||||
assert!(last, "last callback must have done=true");
|
|
||||||
// All earlier callbacks should have done=false.
|
|
||||||
for &d in done_flags.iter().rev().skip(1) {
|
|
||||||
assert!(!d, "non-final callbacks must have done=false");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_progress_total_rows_known() {
|
|
||||||
let db = connect("memory://").execute().await.unwrap();
|
|
||||||
|
|
||||||
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
|
||||||
let table = db
|
|
||||||
.create_table("total_known", batch)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
|
|
||||||
let seen = seen_total.clone();
|
|
||||||
|
|
||||||
// RecordBatch implements Scannable with num_rows() -> Some(3)
|
|
||||||
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
|
|
||||||
table
|
|
||||||
.add(new_data)
|
|
||||||
.progress(move |p| {
|
|
||||||
seen.lock().unwrap().push(p.total_rows());
|
|
||||||
})
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let totals = seen_total.lock().unwrap();
|
|
||||||
// All callbacks (including done) should have total_rows = Some(3)
|
|
||||||
assert!(
|
|
||||||
totals.contains(&Some(3)),
|
|
||||||
"expected total_rows=Some(3) in at least one callback, got: {:?}",
|
|
||||||
*totals
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_progress_total_rows_unknown() {
|
|
||||||
use arrow_array::RecordBatchIterator;
|
|
||||||
|
|
||||||
let db = connect("memory://").execute().await.unwrap();
|
|
||||||
|
|
||||||
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
|
||||||
let table = db
|
|
||||||
.create_table("total_unknown", batch)
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
|
|
||||||
let seen = seen_total.clone();
|
|
||||||
|
|
||||||
// RecordBatchReader does not provide num_rows, so total_rows should be
|
|
||||||
// None in intermediate callbacks but always Some on the done callback.
|
|
||||||
let schema = arrow_schema::Schema::new(vec![arrow_schema::Field::new(
|
|
||||||
"id",
|
|
||||||
arrow_schema::DataType::Int32,
|
|
||||||
false,
|
|
||||||
)]);
|
|
||||||
let new_data: Box<dyn arrow_array::RecordBatchReader + Send> =
|
|
||||||
Box::new(RecordBatchIterator::new(
|
|
||||||
vec![Ok(record_batch!(("id", Int32, [4, 5, 6])).unwrap())],
|
|
||||||
Arc::new(schema),
|
|
||||||
));
|
|
||||||
table
|
|
||||||
.add(new_data)
|
|
||||||
.progress(move |p| {
|
|
||||||
seen.lock().unwrap().push((p.total_rows(), p.done()));
|
|
||||||
})
|
|
||||||
.execute()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let entries = seen_total.lock().unwrap();
|
|
||||||
assert!(!entries.is_empty(), "at least one callback must fire");
|
|
||||||
for (total, done) in entries.iter() {
|
|
||||||
if *done {
|
|
||||||
assert!(
|
|
||||||
total.is_some(),
|
|
||||||
"done callback must have total_rows set, got: {:?}",
|
|
||||||
total
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
assert_eq!(
|
|
||||||
*total, None,
|
|
||||||
"intermediate callback must have total_rows=None, got: {:?}",
|
|
||||||
total
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_record_batch_recovers_from_poisoned_callback_lock() {
|
|
||||||
use super::{ProgressCallback, WriteProgressTracker};
|
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
|
|
||||||
|
|
||||||
// Poison the callback mutex
|
|
||||||
let cb_clone = callback.clone();
|
|
||||||
let handle = std::thread::spawn(move || {
|
|
||||||
let _guard = cb_clone.lock().unwrap();
|
|
||||||
panic!("intentional panic to poison callback mutex");
|
|
||||||
});
|
|
||||||
let _ = handle.join();
|
|
||||||
assert!(
|
|
||||||
callback.lock().is_err(),
|
|
||||||
"callback mutex should be poisoned"
|
|
||||||
);
|
|
||||||
|
|
||||||
let tracker = WriteProgressTracker::new(callback, Some(100));
|
|
||||||
|
|
||||||
// record_batch should not panic
|
|
||||||
tracker.record_batch(10, 1024);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_finish_recovers_from_poisoned_callback_lock() {
|
|
||||||
use super::{ProgressCallback, WriteProgressTracker};
|
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
|
|
||||||
|
|
||||||
// Poison the callback mutex
|
|
||||||
let cb_clone = callback.clone();
|
|
||||||
let handle = std::thread::spawn(move || {
|
|
||||||
let _guard = cb_clone.lock().unwrap();
|
|
||||||
panic!("intentional panic to poison callback mutex");
|
|
||||||
});
|
|
||||||
let _ = handle.join();
|
|
||||||
|
|
||||||
let tracker = WriteProgressTracker::new(callback, Some(100));
|
|
||||||
|
|
||||||
// finish should not panic
|
|
||||||
tracker.finish();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -122,7 +122,7 @@ where
|
|||||||
/// This is a cheap synchronous check useful as a fast path before
|
/// This is a cheap synchronous check useful as a fast path before
|
||||||
/// constructing a fetch closure for [`get()`](Self::get).
|
/// constructing a fetch closure for [`get()`](Self::get).
|
||||||
pub fn try_get(&self) -> Option<V> {
|
pub fn try_get(&self) -> Option<V> {
|
||||||
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
let cache = self.inner.lock().unwrap();
|
||||||
cache.state.fresh_value(self.ttl, self.refresh_window)
|
cache.state.fresh_value(self.ttl, self.refresh_window)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ where
|
|||||||
{
|
{
|
||||||
// Fast path: check if cache is fresh
|
// Fast path: check if cache is fresh
|
||||||
{
|
{
|
||||||
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
let cache = self.inner.lock().unwrap();
|
||||||
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
|
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
|
||||||
return Ok(value);
|
return Ok(value);
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ where
|
|||||||
// Slow path
|
// Slow path
|
||||||
let mut fetch = Some(fetch);
|
let mut fetch = Some(fetch);
|
||||||
let action = {
|
let action = {
|
||||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
let mut cache = self.inner.lock().unwrap();
|
||||||
self.determine_action(&mut cache, &mut fetch)
|
self.determine_action(&mut cache, &mut fetch)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -161,7 +161,7 @@ where
|
|||||||
///
|
///
|
||||||
/// This avoids a blocking fetch on the first [`get()`](Self::get) call.
|
/// This avoids a blocking fetch on the first [`get()`](Self::get) call.
|
||||||
pub fn seed(&self, value: V) {
|
pub fn seed(&self, value: V) {
|
||||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
let mut cache = self.inner.lock().unwrap();
|
||||||
cache.state = State::Current(value, clock::now());
|
cache.state = State::Current(value, clock::now());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ where
|
|||||||
/// Any in-flight background fetch from before this call will not update the
|
/// Any in-flight background fetch from before this call will not update the
|
||||||
/// cache (the generation counter prevents stale writes).
|
/// cache (the generation counter prevents stale writes).
|
||||||
pub fn invalidate(&self) {
|
pub fn invalidate(&self) {
|
||||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
let mut cache = self.inner.lock().unwrap();
|
||||||
cache.state = State::Empty;
|
cache.state = State::Empty;
|
||||||
cache.generation += 1;
|
cache.generation += 1;
|
||||||
}
|
}
|
||||||
@@ -267,7 +267,7 @@ where
|
|||||||
let fut_for_spawn = shared.clone();
|
let fut_for_spawn = shared.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let result = fut_for_spawn.await;
|
let result = fut_for_spawn.await;
|
||||||
let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner());
|
let mut cache = inner.lock().unwrap();
|
||||||
// Only update if no invalidation has happened since we started
|
// Only update if no invalidation has happened since we started
|
||||||
if cache.generation != generation {
|
if cache.generation != generation {
|
||||||
return;
|
return;
|
||||||
@@ -590,67 +590,4 @@ mod tests {
|
|||||||
let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
|
let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
|
||||||
assert_eq!(v, "fresh");
|
assert_eq!(v, "fresh");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper: poison the inner mutex of a BackgroundCache.
|
|
||||||
fn poison_cache(cache: &BackgroundCache<String, TestError>) {
|
|
||||||
let inner = cache.inner.clone();
|
|
||||||
let handle = std::thread::spawn(move || {
|
|
||||||
let _guard = inner.lock().unwrap();
|
|
||||||
panic!("intentional panic to poison mutex");
|
|
||||||
});
|
|
||||||
let _ = handle.join();
|
|
||||||
assert!(cache.inner.lock().is_err(), "mutex should be poisoned");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_try_get_recovers_from_poisoned_lock() {
|
|
||||||
let cache = new_cache();
|
|
||||||
let count = Arc::new(AtomicUsize::new(0));
|
|
||||||
|
|
||||||
// Seed a value first
|
|
||||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
|
|
||||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek
|
|
||||||
|
|
||||||
poison_cache(&cache);
|
|
||||||
|
|
||||||
// try_get() should not panic — it recovers via unwrap_or_else
|
|
||||||
let result = cache.try_get();
|
|
||||||
// The value may or may not be fresh depending on timing, but it must not panic
|
|
||||||
let _ = result;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_get_recovers_from_poisoned_lock() {
|
|
||||||
let cache = new_cache();
|
|
||||||
let count = Arc::new(AtomicUsize::new(0));
|
|
||||||
|
|
||||||
poison_cache(&cache);
|
|
||||||
|
|
||||||
// get() should not panic — it recovers and can still fetch
|
|
||||||
let result = cache.get(ok_fetcher(count.clone(), "recovered")).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
assert_eq!(result.unwrap(), "recovered");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_seed_recovers_from_poisoned_lock() {
|
|
||||||
let cache = new_cache();
|
|
||||||
poison_cache(&cache);
|
|
||||||
|
|
||||||
// seed() should not panic
|
|
||||||
cache.seed("seeded".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_invalidate_recovers_from_poisoned_lock() {
|
|
||||||
let cache = new_cache();
|
|
||||||
let count = Arc::new(AtomicUsize::new(0));
|
|
||||||
|
|
||||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
|
|
||||||
|
|
||||||
poison_cache(&cache);
|
|
||||||
|
|
||||||
// invalidate() should not panic
|
|
||||||
cache.invalidate();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -335,85 +335,6 @@ impl Stream for TimeoutStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
|
|
||||||
pub struct MaxBatchLengthStream {
|
|
||||||
inner: SendableRecordBatchStream,
|
|
||||||
max_batch_length: Option<usize>,
|
|
||||||
buffered_batch: Option<RecordBatch>,
|
|
||||||
buffered_offset: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MaxBatchLengthStream {
|
|
||||||
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
inner,
|
|
||||||
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
|
|
||||||
buffered_batch: None,
|
|
||||||
buffered_offset: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_boxed(
|
|
||||||
inner: SendableRecordBatchStream,
|
|
||||||
max_batch_length: usize,
|
|
||||||
) -> SendableRecordBatchStream {
|
|
||||||
if max_batch_length == 0 {
|
|
||||||
inner
|
|
||||||
} else {
|
|
||||||
Box::pin(Self::new(inner, max_batch_length))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RecordBatchStream for MaxBatchLengthStream {
|
|
||||||
fn schema(&self) -> SchemaRef {
|
|
||||||
self.inner.schema()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for MaxBatchLengthStream {
|
|
||||||
type Item = DataFusionResult<RecordBatch>;
|
|
||||||
|
|
||||||
fn poll_next(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> std::task::Poll<Option<Self::Item>> {
|
|
||||||
loop {
|
|
||||||
let Some(max_batch_length) = self.max_batch_length else {
|
|
||||||
return Pin::new(&mut self.inner).poll_next(cx);
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(batch) = self.buffered_batch.clone() {
|
|
||||||
if self.buffered_offset < batch.num_rows() {
|
|
||||||
let remaining = batch.num_rows() - self.buffered_offset;
|
|
||||||
let length = remaining.min(max_batch_length);
|
|
||||||
let sliced = batch.slice(self.buffered_offset, length);
|
|
||||||
self.buffered_offset += length;
|
|
||||||
if self.buffered_offset >= batch.num_rows() {
|
|
||||||
self.buffered_batch = None;
|
|
||||||
self.buffered_offset = 0;
|
|
||||||
}
|
|
||||||
return std::task::Poll::Ready(Some(Ok(sliced)));
|
|
||||||
}
|
|
||||||
|
|
||||||
self.buffered_batch = None;
|
|
||||||
self.buffered_offset = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
|
||||||
std::task::Poll::Ready(Some(Ok(batch))) => {
|
|
||||||
if batch.num_rows() <= max_batch_length {
|
|
||||||
return std::task::Poll::Ready(Some(Ok(batch)));
|
|
||||||
}
|
|
||||||
self.buffered_batch = Some(batch);
|
|
||||||
self.buffered_offset = 0;
|
|
||||||
}
|
|
||||||
other => return other,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use arrow_array::Int32Array;
|
use arrow_array::Int32Array;
|
||||||
@@ -549,7 +470,7 @@ mod tests {
|
|||||||
assert_eq!(string_to_datatype(string), Some(expected));
|
assert_eq!(string_to_datatype(string), Some(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_batch(num_rows: i32) -> RecordBatch {
|
fn sample_batch() -> RecordBatch {
|
||||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||||
"col1",
|
"col1",
|
||||||
DataType::Int32,
|
DataType::Int32,
|
||||||
@@ -557,14 +478,14 @@ mod tests {
|
|||||||
)]));
|
)]));
|
||||||
RecordBatch::try_new(
|
RecordBatch::try_new(
|
||||||
schema.clone(),
|
schema.clone(),
|
||||||
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
|
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timeout_stream() {
|
async fn test_timeout_stream() {
|
||||||
let batch = sample_batch(3);
|
let batch = sample_batch();
|
||||||
let schema = batch.schema();
|
let schema = batch.schema();
|
||||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||||
|
|
||||||
@@ -594,7 +515,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timeout_stream_zero_duration() {
|
async fn test_timeout_stream_zero_duration() {
|
||||||
let batch = sample_batch(3);
|
let batch = sample_batch();
|
||||||
let schema = batch.schema();
|
let schema = batch.schema();
|
||||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||||
|
|
||||||
@@ -613,7 +534,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timeout_stream_completes_normally() {
|
async fn test_timeout_stream_completes_normally() {
|
||||||
let batch = sample_batch(3);
|
let batch = sample_batch();
|
||||||
let schema = batch.schema();
|
let schema = batch.schema();
|
||||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||||
|
|
||||||
@@ -631,35 +552,4 @@ mod tests {
|
|||||||
// Stream should be empty now
|
// Stream should be empty now
|
||||||
assert!(timeout_stream.next().await.is_none());
|
assert!(timeout_stream.next().await.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn collect_batch_sizes(
|
|
||||||
stream: SendableRecordBatchStream,
|
|
||||||
max_batch_length: usize,
|
|
||||||
) -> Vec<usize> {
|
|
||||||
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
|
|
||||||
sliced_stream
|
|
||||||
.by_ref()
|
|
||||||
.map(|batch| batch.unwrap().num_rows())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_max_batch_length_stream_behaviors() {
|
|
||||||
let schema = sample_batch(7).schema();
|
|
||||||
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
|
|
||||||
|
|
||||||
let sendable_stream: SendableRecordBatchStream =
|
|
||||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
|
||||||
assert_eq!(
|
|
||||||
collect_batch_sizes(sendable_stream, 3).await,
|
|
||||||
vec![2, 3, 3, 1]
|
|
||||||
);
|
|
||||||
|
|
||||||
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
|
|
||||||
schema,
|
|
||||||
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
|
|
||||||
));
|
|
||||||
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user