Compare commits

..

5 Commits

Author SHA1 Message Date
Lance Release
b67f13f642 Bump version: 0.21.2-beta.2 → 0.21.2 2025-03-26 16:27:05 +00:00
Lance Release
2f12d67469 Bump version: 0.21.2-beta.1 → 0.21.2-beta.2 2025-03-26 16:27:05 +00:00
Lance Release
8d7cc29abb Bump version: 0.18.2-beta.0 → 0.18.2-beta.1 2025-03-26 16:24:17 +00:00
Lance Release
a4404e9e18 Bump version: 0.21.2-beta.0 → 0.21.2-beta.1 2025-03-26 16:23:37 +00:00
Will Jones
077e5bb586 upgrade to 0.25.0 2025-03-26 09:19:48 -07:00
107 changed files with 1373 additions and 5990 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.19.0-beta.11" current_version = "0.18.2-beta.1"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -18,24 +18,17 @@ concurrency:
group: "pages" group: "pages"
cancel-in-progress: true cancel-in-progress: true
env:
# This reduces the disk space needed for the build
RUSTFLAGS: "-C debuginfo=0"
# according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html
# CI builds are faster with incremental disabled.
CARGO_INCREMENTAL: "0"
jobs: jobs:
# Single deploy job since we're just deploying # Single deploy job since we're just deploying
build: build:
environment: environment:
name: github-pages name: github-pages
url: ${{ steps.deployment.outputs.page_url }} url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-24.04 runs-on: buildjet-8vcpu-ubuntu-2204
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies needed for ubuntu - name: Install dependecies needed for ubuntu
run: | run: |
sudo apt install -y protobuf-compiler libssl-dev sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default rustup update && rustup default
@@ -45,7 +38,6 @@ jobs:
python-version: "3.10" python-version: "3.10"
cache: "pip" cache: "pip"
cache-dependency-path: "docs/requirements.txt" cache-dependency-path: "docs/requirements.txt"
- uses: Swatinem/rust-cache@v2
- name: Build Python - name: Build Python
working-directory: python working-directory: python
run: | run: |
@@ -57,6 +49,7 @@ jobs:
node-version: 20 node-version: 20
cache: 'npm' cache: 'npm'
cache-dependency-path: node/package-lock.json cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install node dependencies - name: Install node dependencies
working-directory: node working-directory: node
run: | run: |

View File

@@ -43,7 +43,7 @@ jobs:
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- uses: actions-rust-lang/setup-rust-toolchain@v1 - uses: actions-rust-lang/setup-rust-toolchain@v1
with: with:
toolchain: "1.81.0" toolchain: "1.79.0"
cache-workspaces: "./java/core/lancedb-jni" cache-workspaces: "./java/core/lancedb-jni"
# Disable full debug symbol generation to speed up CI build and keep memory down # Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks. # "1" means line tables only, which is useful for panic tracebacks.
@@ -97,7 +97,7 @@ jobs:
- name: Dry run - name: Dry run
if: github.event_name == 'pull_request' if: github.event_name == 'pull_request'
run: | run: |
mvn --batch-mode -DskipTests -Drust.release.build=true package mvn --batch-mode -DskipTests package
- name: Set github - name: Set github
run: | run: |
git config --global user.email "LanceDB Github Runner" git config --global user.email "LanceDB Github Runner"
@@ -108,7 +108,7 @@ jobs:
echo "use-agent" >> ~/.gnupg/gpg.conf echo "use-agent" >> ~/.gnupg/gpg.conf
echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf
export GPG_TTY=$(tty) export GPG_TTY=$(tty)
mvn --batch-mode -DskipTests -Drust.release.build=true -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh mvn --batch-mode -DskipTests -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh
env: env:
SONATYPE_USER: ${{ secrets.SONATYPE_USER }} SONATYPE_USER: ${{ secrets.SONATYPE_USER }}
SONATYPE_TOKEN: ${{ secrets.SONATYPE_TOKEN }} SONATYPE_TOKEN: ${{ secrets.SONATYPE_TOKEN }}

View File

@@ -18,7 +18,6 @@ on:
# This should trigger a dry run (we skip the final publish step) # This should trigger a dry run (we skip the final publish step)
paths: paths:
- .github/workflows/npm-publish.yml - .github/workflows/npm-publish.yml
- Cargo.toml # Change in dependency frequently breaks builds
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
@@ -131,24 +130,29 @@ jobs:
set -e && set -e &&
apt-get update && apt-get update &&
apt-get install -y protobuf-compiler pkg-config apt-get install -y protobuf-compiler pkg-config
- target: x86_64-unknown-linux-musl
# This one seems to need some extra memory # TODO: re-enable x64 musl builds. I could not figure out why, but it
host: ubuntu-2404-8x-x64 # consistently made GHA runners non-responsive at the end of build. Example:
# https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile # https://github.com/lancedb/lancedb/actions/runs/13980431071/job/39144319470?pr=2250
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
features: fp16kernels # - target: x86_64-unknown-linux-musl
pre_build: |- # # This one seems to need some extra memory
set -e && # host: ubuntu-2404-8x-x64
apk add protobuf-dev curl && # # https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile
ln -s /usr/lib/gcc/x86_64-alpine-linux-musl/14.2.0/crtbeginS.o /usr/lib/crtbeginS.o && # docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
ln -s /usr/lib/libgcc_s.so /usr/lib/libgcc.so && # features: ","
CC=gcc && # pre_build: |-
CXX=g++ # set -e &&
# apk add protobuf-dev curl &&
# ln -s /usr/lib/gcc/x86_64-alpine-linux-musl/14.2.0/crtbeginS.o /usr/lib/crtbeginS.o &&
# ln -s /usr/lib/libgcc_s.so /usr/lib/libgcc.so
- target: aarch64-unknown-linux-gnu - target: aarch64-unknown-linux-gnu
host: ubuntu-2404-8x-x64 host: ubuntu-2404-8x-x64
# https://github.com/napi-rs/napi-rs/blob/main/debian-aarch64.Dockerfile # https://github.com/napi-rs/napi-rs/blob/main/debian-aarch64.Dockerfile
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-debian-aarch64 docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-debian-aarch64
features: "fp16kernels" # TODO: enable fp16kernels after https://github.com/lancedb/lance/pull/3559
features: ","
pre_build: |- pre_build: |-
set -e && set -e &&
apt-get update && apt-get update &&
@@ -166,8 +170,8 @@ jobs:
set -e && set -e &&
apk add protobuf-dev && apk add protobuf-dev &&
rustup target add aarch64-unknown-linux-musl && rustup target add aarch64-unknown-linux-musl &&
export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc && export CC="/aarch64-linux-musl-cross/bin/aarch64-linux-musl-gcc" &&
export CXX_aarch64_unknown_linux_musl=aarch64-linux-musl-g++ export CXX="/aarch64-linux-musl-cross/bin/aarch64-linux-musl-g++"
name: build - ${{ matrix.settings.target }} name: build - ${{ matrix.settings.target }}
runs-on: ${{ matrix.settings.host }} runs-on: ${{ matrix.settings.host }}
defaults: defaults:
@@ -531,12 +535,6 @@ jobs:
for filename in *.tgz; do for filename in *.tgz; do
npm publish $PUBLISH_ARGS $filename npm publish $PUBLISH_ARGS $filename
done done
- name: Deprecate
env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
# We need to deprecate the old package to avoid confusion.
# Each time we publish a new version, it gets undeprecated.
run: npm deprecate vectordb "Use @lancedb/lancedb instead."
- name: Notify Slack Action - name: Notify Slack Action
uses: ravsamhq/notify-slack-action@2.3.0 uses: ravsamhq/notify-slack-action@2.3.0
if: ${{ always() }} if: ${{ always() }}

View File

@@ -8,7 +8,6 @@ on:
# This should trigger a dry run (we skip the final publish step) # This should trigger a dry run (we skip the final publish step)
paths: paths:
- .github/workflows/pypi-publish.yml - .github/workflows/pypi-publish.yml
- Cargo.toml # Change in dependency frequently breaks builds
jobs: jobs:
linux: linux:

View File

@@ -136,9 +136,9 @@ jobs:
- uses: ./.github/workflows/run_tests - uses: ./.github/workflows/run_tests
with: with:
integration: true integration: true
- name: Test without pylance or pandas - name: Test without pylance
run: | run: |
pip uninstall -y pylance pandas pip uninstall -y pylance
pytest -vv python/tests/test_table.py pytest -vv python/tests/test_table.py
# Make sure wheels are not included in the Rust cache # Make sure wheels are not included in the Rust cache
- name: Delete wheels - name: Delete wheels

916
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -21,14 +21,16 @@ categories = ["database-implementations"]
rust-version = "1.78.0" rust-version = "1.78.0"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.26.0", "features" = ["dynamodb"] } lance = { "version" = "=0.25.0", "features" = [
lance-io = "=0.26.0" "dynamodb",
lance-index = "=0.26.0" ] }
lance-linalg = "=0.26.0" lance-io = { version = "=0.25.0" }
lance-table = "=0.26.0" lance-index = { version = "=0.25.0" }
lance-testing = "=0.26.0" lance-linalg = { version = "=0.25.0" }
lance-datafusion = "=0.26.0" lance-table = { version = "=0.25.0" }
lance-encoding = "=0.26.0" lance-testing = { version = "=0.25.0" }
lance-datafusion = { version = "=0.25.0" }
lance-encoding = { version = "=0.25.0" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "54.1", optional = false } arrow = { version = "54.1", optional = false }
arrow-array = "54.1" arrow-array = "54.1"
@@ -39,12 +41,12 @@ arrow-schema = "54.1"
arrow-arith = "54.1" arrow-arith = "54.1"
arrow-cast = "54.1" arrow-cast = "54.1"
async-trait = "0" async-trait = "0"
datafusion = { version = "46.0", default-features = false } datafusion = { version = "45.0", default-features = false }
datafusion-catalog = "46.0" datafusion-catalog = "45.0"
datafusion-common = { version = "46.0", default-features = false } datafusion-common = { version = "45.0", default-features = false }
datafusion-execution = "46.0" datafusion-execution = "45.0"
datafusion-expr = "46.0" datafusion-expr = "45.0"
datafusion-physical-plan = "46.0" datafusion-physical-plan = "45.0"
env_logger = "0.11" env_logger = "0.11"
half = { "version" = "=2.4.1", default-features = false, features = [ half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits", "num-traits",

View File

@@ -2,7 +2,7 @@
LanceDB docs are deployed to https://lancedb.github.io/lancedb/. LanceDB docs are deployed to https://lancedb.github.io/lancedb/.
Docs is built and deployed automatically by [Github Actions](../.github/workflows/docs.yml) Docs is built and deployed automatically by [Github Actions](.github/workflows/docs.yml)
whenever a commit is pushed to the `main` branch. So it is possible for the docs to show whenever a commit is pushed to the `main` branch. So it is possible for the docs to show
unreleased features. unreleased features.

View File

@@ -342,7 +342,7 @@ For **read and write access**, LanceDB will need a policy such as:
"Action": [ "Action": [
"s3:PutObject", "s3:PutObject",
"s3:GetObject", "s3:GetObject",
"s3:DeleteObject" "s3:DeleteObject",
], ],
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*" "Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
}, },
@@ -374,7 +374,7 @@ For **read-only access**, LanceDB will need a policy such as:
{ {
"Effect": "Allow", "Effect": "Allow",
"Action": [ "Action": [
"s3:GetObject" "s3:GetObject",
], ],
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*" "Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
}, },

View File

@@ -765,10 +765,7 @@ This can be used to update zero to all rows depending on how many rows match the
]; ];
const tbl = await db.createTable("my_table", data) const tbl = await db.createTable("my_table", data)
await tbl.update({ await tbl.update({vector: [10, 10]}, { where: "x = 2"})
values: { vector: [10, 10] },
where: "x = 2"
});
``` ```
=== "vectordb (deprecated)" === "vectordb (deprecated)"
@@ -787,10 +784,7 @@ This can be used to update zero to all rows depending on how many rows match the
]; ];
const tbl = await db.createTable("my_table", data) const tbl = await db.createTable("my_table", data)
await tbl.update({ await tbl.update({ where: "x = 2", values: {vector: [10, 10]} })
where: "x = 2",
values: { vector: [10, 10] }
});
``` ```
#### Updating using a sql query #### Updating using a sql query

View File

@@ -1,67 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / BoostQuery
# Class: BoostQuery
Represents a full-text query interface.
This interface defines the structure and behavior for full-text queries,
including methods to retrieve the query type and convert the query to a dictionary format.
## Implements
- [`FullTextQuery`](../interfaces/FullTextQuery.md)
## Constructors
### new BoostQuery()
```ts
new BoostQuery(
positive,
negative,
options?): BoostQuery
```
Creates an instance of BoostQuery.
The boost returns documents that match the positive query,
but penalizes those that match the negative query.
the penalty is controlled by the `negativeBoost` parameter.
#### Parameters
* **positive**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
The positive query that boosts the relevance score.
* **negative**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
The negative query that reduces the relevance score.
* **options?**
Optional parameters for the boost query.
- `negativeBoost`: The boost factor for the negative query (default is 0.0).
* **options.negativeBoost?**: `number`
#### Returns
[`BoostQuery`](BoostQuery.md)
## Methods
### queryType()
```ts
queryType(): FullTextQueryType
```
The type of the full-text query.
#### Returns
[`FullTextQueryType`](../enumerations/FullTextQueryType.md)
#### Implementation of
[`FullTextQuery`](../interfaces/FullTextQuery.md).[`queryType`](../interfaces/FullTextQuery.md#querytype)

View File

@@ -1,70 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / MatchQuery
# Class: MatchQuery
Represents a full-text query interface.
This interface defines the structure and behavior for full-text queries,
including methods to retrieve the query type and convert the query to a dictionary format.
## Implements
- [`FullTextQuery`](../interfaces/FullTextQuery.md)
## Constructors
### new MatchQuery()
```ts
new MatchQuery(
query,
column,
options?): MatchQuery
```
Creates an instance of MatchQuery.
#### Parameters
* **query**: `string`
The text query to search for.
* **column**: `string`
The name of the column to search within.
* **options?**
Optional parameters for the match query.
- `boost`: The boost factor for the query (default is 1.0).
- `fuzziness`: The fuzziness level for the query (default is 0).
- `maxExpansions`: The maximum number of terms to consider for fuzzy matching (default is 50).
* **options.boost?**: `number`
* **options.fuzziness?**: `number`
* **options.maxExpansions?**: `number`
#### Returns
[`MatchQuery`](MatchQuery.md)
## Methods
### queryType()
```ts
queryType(): FullTextQueryType
```
The type of the full-text query.
#### Returns
[`FullTextQueryType`](../enumerations/FullTextQueryType.md)
#### Implementation of
[`FullTextQuery`](../interfaces/FullTextQuery.md).[`queryType`](../interfaces/FullTextQuery.md#querytype)

View File

@@ -1,64 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / MultiMatchQuery
# Class: MultiMatchQuery
Represents a full-text query interface.
This interface defines the structure and behavior for full-text queries,
including methods to retrieve the query type and convert the query to a dictionary format.
## Implements
- [`FullTextQuery`](../interfaces/FullTextQuery.md)
## Constructors
### new MultiMatchQuery()
```ts
new MultiMatchQuery(
query,
columns,
options?): MultiMatchQuery
```
Creates an instance of MultiMatchQuery.
#### Parameters
* **query**: `string`
The text query to search for across multiple columns.
* **columns**: `string`[]
An array of column names to search within.
* **options?**
Optional parameters for the multi-match query.
- `boosts`: An array of boost factors for each column (default is 1.0 for all).
* **options.boosts?**: `number`[]
#### Returns
[`MultiMatchQuery`](MultiMatchQuery.md)
## Methods
### queryType()
```ts
queryType(): FullTextQueryType
```
The type of the full-text query.
#### Returns
[`FullTextQueryType`](../enumerations/FullTextQueryType.md)
#### Implementation of
[`FullTextQuery`](../interfaces/FullTextQuery.md).[`queryType`](../interfaces/FullTextQuery.md#querytype)

View File

@@ -1,55 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / PhraseQuery
# Class: PhraseQuery
Represents a full-text query interface.
This interface defines the structure and behavior for full-text queries,
including methods to retrieve the query type and convert the query to a dictionary format.
## Implements
- [`FullTextQuery`](../interfaces/FullTextQuery.md)
## Constructors
### new PhraseQuery()
```ts
new PhraseQuery(query, column): PhraseQuery
```
Creates an instance of `PhraseQuery`.
#### Parameters
* **query**: `string`
The phrase to search for in the specified column.
* **column**: `string`
The name of the column to search within.
#### Returns
[`PhraseQuery`](PhraseQuery.md)
## Methods
### queryType()
```ts
queryType(): FullTextQueryType
```
The type of the full-text query.
#### Returns
[`FullTextQueryType`](../enumerations/FullTextQueryType.md)
#### Implementation of
[`FullTextQuery`](../interfaces/FullTextQuery.md).[`queryType`](../interfaces/FullTextQuery.md#querytype)

View File

@@ -30,53 +30,6 @@ protected inner: Query | Promise<Query>;
## Methods ## Methods
### analyzePlan()
```ts
analyzePlan(): Promise<string>
```
Executes the query and returns the physical query plan annotated with runtime metrics.
This is useful for debugging and performance analysis, as it shows how the query was executed
and includes metrics such as elapsed time, rows processed, and I/O statistics.
#### Returns
`Promise`&lt;`string`&gt;
A query execution plan with runtime metrics for each step.
#### Example
```ts
import * as lancedb from "@lancedb/lancedb"
const db = await lancedb.connect("./.lancedb");
const table = await db.createTable("my_table", [
{ vector: [1.1, 0.9], id: "1" },
]);
const plan = await table.query().nearestTo([0.5, 0.2]).analyzePlan();
Example output (with runtime metrics inlined):
AnalyzeExec verbose=true, metrics=[]
ProjectionExec: expr=[id@3 as id, vector@0 as vector, _distance@2 as _distance], metrics=[output_rows=1, elapsed_compute=3.292µs]
Take: columns="vector, _rowid, _distance, (id)", metrics=[output_rows=1, elapsed_compute=66.001µs, batches_processed=1, bytes_read=8, iops=1, requests=1]
CoalesceBatchesExec: target_batch_size=1024, metrics=[output_rows=1, elapsed_compute=3.333µs]
GlobalLimitExec: skip=0, fetch=10, metrics=[output_rows=1, elapsed_compute=167ns]
FilterExec: _distance@2 IS NOT NULL, metrics=[output_rows=1, elapsed_compute=8.542µs]
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], metrics=[output_rows=1, elapsed_compute=63.25µs, row_replacements=1]
KNNVectorDistance: metric=l2, metrics=[output_rows=1, elapsed_compute=114.333µs, output_batches=1]
LanceScan: uri=/path/to/data, projection=[vector], row_id=true, row_addr=false, ordered=false, metrics=[output_rows=1, elapsed_compute=103.626µs, bytes_read=549, iops=2, requests=2]
```
#### Inherited from
[`QueryBase`](QueryBase.md).[`analyzePlan`](QueryBase.md#analyzeplan)
***
### execute() ### execute()
```ts ```ts
@@ -206,7 +159,7 @@ fullTextSearch(query, options?): this
#### Parameters #### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md) * **query**: `string`
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt; * **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;
@@ -309,7 +262,7 @@ nearestToText(query, columns?): Query
#### Parameters #### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md) * **query**: `string`
* **columns?**: `string`[] * **columns?**: `string`[]

View File

@@ -36,49 +36,6 @@ protected inner: NativeQueryType | Promise<NativeQueryType>;
## Methods ## Methods
### analyzePlan()
```ts
analyzePlan(): Promise<string>
```
Executes the query and returns the physical query plan annotated with runtime metrics.
This is useful for debugging and performance analysis, as it shows how the query was executed
and includes metrics such as elapsed time, rows processed, and I/O statistics.
#### Returns
`Promise`&lt;`string`&gt;
A query execution plan with runtime metrics for each step.
#### Example
```ts
import * as lancedb from "@lancedb/lancedb"
const db = await lancedb.connect("./.lancedb");
const table = await db.createTable("my_table", [
{ vector: [1.1, 0.9], id: "1" },
]);
const plan = await table.query().nearestTo([0.5, 0.2]).analyzePlan();
Example output (with runtime metrics inlined):
AnalyzeExec verbose=true, metrics=[]
ProjectionExec: expr=[id@3 as id, vector@0 as vector, _distance@2 as _distance], metrics=[output_rows=1, elapsed_compute=3.292µs]
Take: columns="vector, _rowid, _distance, (id)", metrics=[output_rows=1, elapsed_compute=66.001µs, batches_processed=1, bytes_read=8, iops=1, requests=1]
CoalesceBatchesExec: target_batch_size=1024, metrics=[output_rows=1, elapsed_compute=3.333µs]
GlobalLimitExec: skip=0, fetch=10, metrics=[output_rows=1, elapsed_compute=167ns]
FilterExec: _distance@2 IS NOT NULL, metrics=[output_rows=1, elapsed_compute=8.542µs]
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], metrics=[output_rows=1, elapsed_compute=63.25µs, row_replacements=1]
KNNVectorDistance: metric=l2, metrics=[output_rows=1, elapsed_compute=114.333µs, output_batches=1]
LanceScan: uri=/path/to/data, projection=[vector], row_id=true, row_addr=false, ordered=false, metrics=[output_rows=1, elapsed_compute=103.626µs, bytes_read=549, iops=2, requests=2]
```
***
### execute() ### execute()
```ts ```ts
@@ -192,7 +149,7 @@ fullTextSearch(query, options?): this
#### Parameters #### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md) * **query**: `string`
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt; * **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;

View File

@@ -454,28 +454,6 @@ Modeled after ``VACUUM`` in PostgreSQL.
*** ***
### prewarmIndex()
```ts
abstract prewarmIndex(name): Promise<void>
```
Prewarm an index in the table.
#### Parameters
* **name**: `string`
The name of the index.
This will load the index into memory. This may reduce the cold-start time for
future queries. If the index does not fit in the cache then this call may be
wasteful.
#### Returns
`Promise`&lt;`void`&gt;
***
### query() ### query()
```ts ```ts
@@ -597,7 +575,7 @@ of the given query
#### Parameters #### Parameters
* **query**: `string` \| [`IntoVector`](../type-aliases/IntoVector.md) \| [`FullTextQuery`](../interfaces/FullTextQuery.md) * **query**: `string` \| [`IntoVector`](../type-aliases/IntoVector.md)
the query, a vector or string the query, a vector or string
* **queryType?**: `string` * **queryType?**: `string`
@@ -753,26 +731,3 @@ Retrieve the version of the table
#### Returns #### Returns
`Promise`&lt;`number`&gt; `Promise`&lt;`number`&gt;
***
### waitForIndex()
```ts
abstract waitForIndex(indexNames, timeoutSeconds): Promise<void>
```
Waits for asynchronous indexing to complete on the table.
#### Parameters
* **indexNames**: `string`[]
The name of the indices to wait for
* **timeoutSeconds**: `number`
The number of seconds to wait before timing out
This will raise an error if the indices are not created and fully indexed within the timeout.
#### Returns
`Promise`&lt;`void`&gt;

View File

@@ -48,53 +48,6 @@ addQueryVector(vector): VectorQuery
*** ***
### analyzePlan()
```ts
analyzePlan(): Promise<string>
```
Executes the query and returns the physical query plan annotated with runtime metrics.
This is useful for debugging and performance analysis, as it shows how the query was executed
and includes metrics such as elapsed time, rows processed, and I/O statistics.
#### Returns
`Promise`&lt;`string`&gt;
A query execution plan with runtime metrics for each step.
#### Example
```ts
import * as lancedb from "@lancedb/lancedb"
const db = await lancedb.connect("./.lancedb");
const table = await db.createTable("my_table", [
{ vector: [1.1, 0.9], id: "1" },
]);
const plan = await table.query().nearestTo([0.5, 0.2]).analyzePlan();
Example output (with runtime metrics inlined):
AnalyzeExec verbose=true, metrics=[]
ProjectionExec: expr=[id@3 as id, vector@0 as vector, _distance@2 as _distance], metrics=[output_rows=1, elapsed_compute=3.292µs]
Take: columns="vector, _rowid, _distance, (id)", metrics=[output_rows=1, elapsed_compute=66.001µs, batches_processed=1, bytes_read=8, iops=1, requests=1]
CoalesceBatchesExec: target_batch_size=1024, metrics=[output_rows=1, elapsed_compute=3.333µs]
GlobalLimitExec: skip=0, fetch=10, metrics=[output_rows=1, elapsed_compute=167ns]
FilterExec: _distance@2 IS NOT NULL, metrics=[output_rows=1, elapsed_compute=8.542µs]
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], metrics=[output_rows=1, elapsed_compute=63.25µs, row_replacements=1]
KNNVectorDistance: metric=l2, metrics=[output_rows=1, elapsed_compute=114.333µs, output_batches=1]
LanceScan: uri=/path/to/data, projection=[vector], row_id=true, row_addr=false, ordered=false, metrics=[output_rows=1, elapsed_compute=103.626µs, bytes_read=549, iops=2, requests=2]
```
#### Inherited from
[`QueryBase`](QueryBase.md).[`analyzePlan`](QueryBase.md#analyzeplan)
***
### bypassVectorIndex() ### bypassVectorIndex()
```ts ```ts
@@ -347,7 +300,7 @@ fullTextSearch(query, options?): this
#### Parameters #### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md) * **query**: `string`
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt; * **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;

View File

@@ -1,46 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / FullTextQueryType
# Enumeration: FullTextQueryType
Enum representing the types of full-text queries supported.
- `Match`: Performs a full-text search for terms in the query string.
- `MatchPhrase`: Searches for an exact phrase match in the text.
- `Boost`: Boosts the relevance score of specific terms in the query.
- `MultiMatch`: Searches across multiple fields for the query terms.
## Enumeration Members
### Boost
```ts
Boost: "boost";
```
***
### Match
```ts
Match: "match";
```
***
### MatchPhrase
```ts
MatchPhrase: "match_phrase";
```
***
### MultiMatch
```ts
MultiMatch: "multi_match";
```

View File

@@ -9,20 +9,12 @@
- [embedding](namespaces/embedding/README.md) - [embedding](namespaces/embedding/README.md)
- [rerankers](namespaces/rerankers/README.md) - [rerankers](namespaces/rerankers/README.md)
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
## Classes ## Classes
- [BoostQuery](classes/BoostQuery.md)
- [Connection](classes/Connection.md) - [Connection](classes/Connection.md)
- [Index](classes/Index.md) - [Index](classes/Index.md)
- [MakeArrowTableOptions](classes/MakeArrowTableOptions.md) - [MakeArrowTableOptions](classes/MakeArrowTableOptions.md)
- [MatchQuery](classes/MatchQuery.md)
- [MergeInsertBuilder](classes/MergeInsertBuilder.md) - [MergeInsertBuilder](classes/MergeInsertBuilder.md)
- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md) - [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md) - [QueryBase](classes/QueryBase.md)
- [RecordBatchIterator](classes/RecordBatchIterator.md) - [RecordBatchIterator](classes/RecordBatchIterator.md)
@@ -41,7 +33,6 @@
- [CreateTableOptions](interfaces/CreateTableOptions.md) - [CreateTableOptions](interfaces/CreateTableOptions.md)
- [ExecutableQuery](interfaces/ExecutableQuery.md) - [ExecutableQuery](interfaces/ExecutableQuery.md)
- [FtsOptions](interfaces/FtsOptions.md) - [FtsOptions](interfaces/FtsOptions.md)
- [FullTextQuery](interfaces/FullTextQuery.md)
- [FullTextSearchOptions](interfaces/FullTextSearchOptions.md) - [FullTextSearchOptions](interfaces/FullTextSearchOptions.md)
- [HnswPqOptions](interfaces/HnswPqOptions.md) - [HnswPqOptions](interfaces/HnswPqOptions.md)
- [HnswSqOptions](interfaces/HnswSqOptions.md) - [HnswSqOptions](interfaces/HnswSqOptions.md)

View File

@@ -1,25 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / FullTextQuery
# Interface: FullTextQuery
Represents a full-text query interface.
This interface defines the structure and behavior for full-text queries,
including methods to retrieve the query type and convert the query to a dictionary format.
## Methods
### queryType()
```ts
queryType(): FullTextQueryType
```
The type of the full-text query.
#### Returns
[`FullTextQueryType`](../enumerations/FullTextQueryType.md)

View File

@@ -39,11 +39,3 @@ and the same name, then an error will be returned. This is true even if
that index is out of date. that index is out of date.
The default is true The default is true
***
### waitTimeoutSeconds?
```ts
optional waitTimeoutSeconds: number;
```

View File

@@ -20,13 +20,3 @@ The maximum number of rows to return in a single batch
Batches may have fewer rows if the underlying data is stored Batches may have fewer rows if the underlying data is stored
in smaller chunks. in smaller chunks.
***
### timeoutMs?
```ts
optional timeoutMs: number;
```
Timeout for query execution in milliseconds

View File

@@ -35,9 +35,3 @@ print the resolved query plan. You can use the `explain_plan` method to do this:
* Python Sync: [LanceQueryBuilder.explain_plan][lancedb.query.LanceQueryBuilder.explain_plan] * Python Sync: [LanceQueryBuilder.explain_plan][lancedb.query.LanceQueryBuilder.explain_plan]
* Python Async: [AsyncQueryBase.explain_plan][lancedb.query.AsyncQueryBase.explain_plan] * Python Async: [AsyncQueryBase.explain_plan][lancedb.query.AsyncQueryBase.explain_plan]
* Node @lancedb/lancedb: [LanceQueryBuilder.explainPlan](/lancedb/js/classes/QueryBase/#explainplan) * Node @lancedb/lancedb: [LanceQueryBuilder.explainPlan](/lancedb/js/classes/QueryBase/#explainplan)
To understand how a query was actually executed—including metrics like execution time, number of rows processed, I/O stats, and more—use the analyze_plan method. This executes the query and returns a physical execution plan annotated with runtime metrics, making it especially helpful for performance tuning and debugging.
* Python Sync: [LanceQueryBuilder.analyze_plan][lancedb.query.LanceQueryBuilder.analyze_plan]
* Python Async: [AsyncQueryBase.analyze_plan][lancedb.query.AsyncQueryBase.analyze_plan]
* Node @lancedb/lancedb: [LanceQueryBuilder.analyzePlan](/lancedb/js/classes/QueryBase/#analyzePlan)

View File

@@ -8,16 +8,13 @@
<parent> <parent>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.11</version> <version>0.18.2-beta.1</version>
<relativePath>../pom.xml</relativePath> <relativePath>../pom.xml</relativePath>
</parent> </parent>
<artifactId>lancedb-core</artifactId> <artifactId>lancedb-core</artifactId>
<name>LanceDB Core</name> <name>LanceDB Core</name>
<packaging>jar</packaging> <packaging>jar</packaging>
<properties>
<rust.release.build>false</rust.release.build>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
@@ -71,7 +68,7 @@
</goals> </goals>
<configuration> <configuration>
<path>lancedb-jni</path> <path>lancedb-jni</path>
<release>${rust.release.build}</release> <release>true</release>
<!-- Copy native libraries to target/classes for runtime access --> <!-- Copy native libraries to target/classes for runtime access -->
<copyTo>${project.build.directory}/classes/nativelib</copyTo> <copyTo>${project.build.directory}/classes/nativelib</copyTo>
<copyWithPlatformDir>true</copyWithPlatformDir> <copyWithPlatformDir>true</copyWithPlatformDir>

View File

@@ -1,25 +1,16 @@
/* // SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License"); // SPDX-FileCopyrightText: Copyright The LanceDB Authors
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lancedb; package com.lancedb.lancedb;
import io.questdb.jar.jni.JarJniLoader; import io.questdb.jar.jni.JarJniLoader;
import java.io.Closeable; import java.io.Closeable;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
/** Represents LanceDB database. */ /**
* Represents LanceDB database.
*/
public class Connection implements Closeable { public class Connection implements Closeable {
static { static {
JarJniLoader.loadLib(Connection.class, "/nativelib", "lancedb_jni"); JarJniLoader.loadLib(Connection.class, "/nativelib", "lancedb_jni");
@@ -27,11 +18,14 @@ public class Connection implements Closeable {
private long nativeConnectionHandle; private long nativeConnectionHandle;
/** Connect to a LanceDB instance. */ /**
* Connect to a LanceDB instance.
*/
public static native Connection connect(String uri); public static native Connection connect(String uri);
/** /**
* Get the names of all tables in the database. The names are sorted in ascending order. * Get the names of all tables in the database. The names are sorted in
* ascending order.
* *
* @return the table names * @return the table names
*/ */
@@ -40,7 +34,8 @@ public class Connection implements Closeable {
} }
/** /**
* Get the names of filtered tables in the database. The names are sorted in ascending order. * Get the names of filtered tables in the database. The names are sorted in
* ascending order.
* *
* @param limit The number of results to return. * @param limit The number of results to return.
* @return the table names * @return the table names
@@ -50,11 +45,12 @@ public class Connection implements Closeable {
} }
/** /**
* Get the names of filtered tables in the database. The names are sorted in ascending order. * Get the names of filtered tables in the database. The names are sorted in
* ascending order.
* *
* @param startAfter If present, only return names that come lexicographically after the supplied * @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination by setting this to the last * value. This can be combined with limit to implement pagination
* table name from the previous page. * by setting this to the last table name from the previous page.
* @return the table names * @return the table names
*/ */
public List<String> tableNames(String startAfter) { public List<String> tableNames(String startAfter) {
@@ -62,11 +58,12 @@ public class Connection implements Closeable {
} }
/** /**
* Get the names of filtered tables in the database. The names are sorted in ascending order. * Get the names of filtered tables in the database. The names are sorted in
* ascending order.
* *
* @param startAfter If present, only return names that come lexicographically after the supplied * @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination by setting this to the last * value. This can be combined with limit to implement pagination
* table name from the previous page. * by setting this to the last table name from the previous page.
* @param limit The number of results to return. * @param limit The number of results to return.
* @return the table names * @return the table names
*/ */
@@ -75,19 +72,22 @@ public class Connection implements Closeable {
} }
/** /**
* Get the names of filtered tables in the database. The names are sorted in ascending order. * Get the names of filtered tables in the database. The names are sorted in
* ascending order.
* *
* @param startAfter If present, only return names that come lexicographically after the supplied * @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination by setting this to the last * value. This can be combined with limit to implement pagination
* table name from the previous page. * by setting this to the last table name from the previous page.
* @param limit The number of results to return. * @param limit The number of results to return.
* @return the table names * @return the table names
*/ */
public native List<String> tableNames(Optional<String> startAfter, Optional<Integer> limit); public native List<String> tableNames(
Optional<String> startAfter, Optional<Integer> limit);
/** /**
* Closes this connection and releases any system resources associated with it. If the connection * Closes this connection and releases any system resources associated with it. If
* is already closed, then invoking this method has no effect. * the connection is
* already closed, then invoking this method has no effect.
*/ */
@Override @Override
public void close() { public void close() {
@@ -98,7 +98,8 @@ public class Connection implements Closeable {
} }
/** /**
* Native method to release the Lance connection resources associated with the given handle. * Native method to release the Lance connection resources associated with the
* given handle.
* *
* @param handle The native handle to the connection resource. * @param handle The native handle to the connection resource.
*/ */

View File

@@ -1,35 +1,27 @@
/* // SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License"); // SPDX-FileCopyrightText: Copyright The LanceDB Authors
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lancedb; package com.lancedb.lancedb;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.net.URL;
import java.nio.file.Path;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.file.Path;
import java.util.List;
import java.net.URL;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
public class ConnectionTest { public class ConnectionTest {
private static final String[] TABLE_NAMES = { private static final String[] TABLE_NAMES = {
"dataset_version", "new_empty_dataset", "test", "write_stream" "dataset_version",
"new_empty_dataset",
"test",
"write_stream"
}; };
@TempDir static Path tempDir; // Temporary directory for the tests @TempDir
static Path tempDir; // Temporary directory for the tests
private static URL lanceDbURL; private static URL lanceDbURL;
@BeforeAll @BeforeAll
@@ -61,21 +53,18 @@ public class ConnectionTest {
@Test @Test
void tableNamesStartAfter() { void tableNamesStartAfter() {
try (Connection conn = Connection.connect(lanceDbURL.toString())) { try (Connection conn = Connection.connect(lanceDbURL.toString())) {
assertTableNamesStartAfter( assertTableNamesStartAfter(conn, TABLE_NAMES[0], 3, TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
conn, TABLE_NAMES[0], 3, TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[1], 2, TABLE_NAMES[2], TABLE_NAMES[3]); assertTableNamesStartAfter(conn, TABLE_NAMES[1], 2, TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[2], 1, TABLE_NAMES[3]); assertTableNamesStartAfter(conn, TABLE_NAMES[2], 1, TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[3], 0); assertTableNamesStartAfter(conn, TABLE_NAMES[3], 0);
assertTableNamesStartAfter( assertTableNamesStartAfter(conn, "a_dataset", 4, TABLE_NAMES[0], TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
conn, "a_dataset", 4, TABLE_NAMES[0], TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "o_dataset", 2, TABLE_NAMES[2], TABLE_NAMES[3]); assertTableNamesStartAfter(conn, "o_dataset", 2, TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "v_dataset", 1, TABLE_NAMES[3]); assertTableNamesStartAfter(conn, "v_dataset", 1, TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "z_dataset", 0); assertTableNamesStartAfter(conn, "z_dataset", 0);
} }
} }
private void assertTableNamesStartAfter( private void assertTableNamesStartAfter(Connection conn, String startAfter, int expectedSize, String... expectedNames) {
Connection conn, String startAfter, int expectedSize, String... expectedNames) {
List<String> tableNames = conn.tableNames(startAfter); List<String> tableNames = conn.tableNames(startAfter);
assertEquals(expectedSize, tableNames.size()); assertEquals(expectedSize, tableNames.size());
for (int i = 0; i < expectedNames.length; i++) { for (int i = 0; i < expectedNames.length; i++) {

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.11</version> <version>0.18.2-beta.1</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>LanceDB Parent</name> <name>LanceDB Parent</name>
@@ -29,25 +29,6 @@
<properties> <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version> <arrow.version>15.0.0</arrow.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
<spotless.delimiter>package</spotless.delimiter>
<spotless.license.header>
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
</spotless.license.header>
</properties> </properties>
<modules> <modules>
@@ -146,8 +127,7 @@
<configuration> <configuration>
<configLocation>google_checks.xml</configLocation> <configLocation>google_checks.xml</configLocation>
<consoleOutput>true</consoleOutput> <consoleOutput>true</consoleOutput>
<failsOnError>false</failsOnError> <failsOnError>true</failsOnError>
<failOnViolation>false</failOnViolation>
<violationSeverity>warning</violationSeverity> <violationSeverity>warning</violationSeverity>
<linkXRef>false</linkXRef> <linkXRef>false</linkXRef>
</configuration> </configuration>
@@ -161,10 +141,6 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
</plugin>
</plugins> </plugins>
<pluginManagement> <pluginManagement>
<plugins> <plugins>
@@ -203,54 +179,6 @@
<artifactId>maven-install-plugin</artifactId> <artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version> <version>2.5.2</version>
</plugin> </plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
<version>${spotless.version}</version>
<configuration>
<skip>${spotless.skip}</skip>
<upToDateChecking>
<enabled>true</enabled>
</upToDateChecking>
<java>
<includes>
<include>src/main/java/**/*.java</include>
<include>src/test/java/**/*.java</include>
</includes>
<googleJavaFormat>
<version>${spotless.java.googlejavaformat.version}</version>
<style>GOOGLE</style>
</googleJavaFormat>
<importOrder>
<order>com.lancedb.lance,,javax,java,\#</order>
</importOrder>
<removeUnusedImports />
</java>
<scala>
<includes>
<include>src/main/scala/**/*.scala</include>
<include>src/main/scala-*/**/*.scala</include>
<include>src/test/scala/**/*.scala</include>
<include>src/test/scala-*/**/*.scala</include>
</includes>
</scala>
<licenseHeader>
<content>${spotless.license.header}</content>
<delimiter>${spotless.delimiter}</delimiter>
</licenseHeader>
</configuration>
<executions>
<execution>
<id>spotless-check</id>
<phase>validate</phase>
<goals>
<goal>apply</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins> </plugins>
</pluginManagement> </pluginManagement>
</build> </build>

51
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.11", "@lancedb/vectordb-darwin-arm64": "0.18.2-beta.0",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.11", "@lancedb/vectordb-darwin-x64": "0.18.2-beta.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.11", "@lancedb/vectordb-linux-arm64-gnu": "0.18.2-beta.0",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.11", "@lancedb/vectordb-linux-x64-gnu": "0.18.2-beta.0",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.11" "@lancedb/vectordb-win32-x64-msvc": "0.18.2-beta.0"
}, },
"peerDependencies": { "peerDependencies": {
"@apache-arrow/ts": "^14.0.2", "@apache-arrow/ts": "^14.0.2",
@@ -327,9 +327,9 @@
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": { "node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.11.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.18.2-beta.0.tgz",
"integrity": "sha512-fLefGJYdlIRIjrJj8MU1r5Zix5LpKktpCYilA7tZrfvBWkubGceJ+U6RPsWz7VGBfWcETo3g5CBooUPhbtSMlQ==", "integrity": "sha512-FzIcElkS6R5I5kU1S5m7yLVTB1Duv1XcmZQtVmYl/JjNlfxS1WTtMzdzMqSBFohDcgU2Tkc5+1FpK1B94dUUbg==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -340,9 +340,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.11.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.18.2-beta.0.tgz",
"integrity": "sha512-FkCa1TbPLDXAGhlRI4tafcltzApCsyvgi+I+kX07u5DKPNQVALpQ3R6X6GLlbiFsAFBdyv9t2fqQ9DlgjJIZpA==", "integrity": "sha512-jv+XludfLNBDm1DjdqyghwDMtd4E+ygwycQpkpK72wyZSh6Qytrgq+4dNi/zCZ3UChFLbKbIxrVxv9yENQn2Pg==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -353,9 +353,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.11.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.18.2-beta.0.tgz",
"integrity": "sha512-iZkL/01HNUNQ8pGK0+hoNyrM7P1YtShsyIQVzJMfo41SAofCBf9qvi9YaYyd49sDb+dQXeRn1+cfaJ9siz1OHw==", "integrity": "sha512-8/fBpbNYhhpetf/pZv0DyPnQkeAbsiICMyCoRiNu5auvQK4AsGF1XvLWrDi68u9F0GysBKvuatYuGqa/yh+Anw==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -366,9 +366,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.11.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.18.2-beta.0.tgz",
"integrity": "sha512-MdKRHxe2tRQqmExNLv3f6Wvx1mEi98eFtD0ysm4tNrQdaS1MJbTp+DUehrRKkfDWsooalHkIi9d02BVw5qseUQ==", "integrity": "sha512-7a1Kc/2V2ff4HlLzXyXVdK0Z0VIFUt50v2SBRdlcycJ0NLW9ZqV+9UjB/NAOwMXVgYd7d3rKjACGkQzkpvcyeg==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -379,9 +379,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.11.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.18.2-beta.0.tgz",
"integrity": "sha512-KWy+t9jr0feJAW9NkmM/w9kfdpp78+7mkeh9lb0g3xI3OdYU1yizNqFjbIQqJf7/L4sou4wmOjAC+FcP8qCtzg==", "integrity": "sha512-EeCiSf2RtJMESnkIca28GI6rAStYj2q9sVIyNCXpmIZSkJVpfQ3iswHGAbHrEfaPl0J1Re9cnRHLLuqkumwiIQ==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -1184,10 +1184,9 @@
} }
}, },
"node_modules/axios": { "node_modules/axios": {
"version": "1.8.4", "version": "1.7.7",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.8.4.tgz", "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.7.tgz",
"integrity": "sha512-eBSYY4Y68NNlHbHBMdeDmKNtDgXWhQsJcGqzO3iLUM0GraQFSS9cVgPX5I9b3lbdFKyYoAEGAZF1DwhTaljNAw==", "integrity": "sha512-S4kL7XrjgBmvdGut0sN3yJxqYzrDOnivkBiN0OFs6hLiUam3UPvswUo0kqGyhqUZGEOytHyumEdXsAkgCOUf3Q==",
"license": "MIT",
"dependencies": { "dependencies": {
"follow-redirects": "^1.15.6", "follow-redirects": "^1.15.6",
"form-data": "^4.0.0", "form-data": "^4.0.0",

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"private": false, "private": false,
"main": "dist/index.js", "main": "dist/index.js",
@@ -89,10 +89,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.11", "@lancedb/vectordb-darwin-x64": "0.18.2-beta.1",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.11", "@lancedb/vectordb-darwin-arm64": "0.18.2-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.11", "@lancedb/vectordb-linux-x64-gnu": "0.18.2-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.11", "@lancedb/vectordb-linux-arm64-gnu": "0.18.2-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.11" "@lancedb/vectordb-win32-x64-msvc": "0.18.2-beta.1"
} }
} }

View File

@@ -1,7 +1,7 @@
[package] [package]
name = "lancedb-nodejs" name = "lancedb-nodejs"
edition.workspace = true edition.workspace = true
version = "0.19.0-beta.11" version = "0.18.2-beta.1"
license.workspace = true license.workspace = true
description.workspace = true description.workspace = true
repository.workspace = true repository.workspace = true
@@ -28,9 +28,6 @@ napi-derive = "2.16.4"
lzma-sys = { version = "*", features = ["static"] } lzma-sys = { version = "*", features = ["static"] }
log.workspace = true log.workspace = true
# Workaround for build failure until we can fix it.
aws-lc-sys = "=0.28.0"
[build-dependencies] [build-dependencies]
napi-build = "2.1" napi-build = "2.1"

View File

@@ -10,7 +10,7 @@ import * as arrow16 from "apache-arrow-16";
import * as arrow17 from "apache-arrow-17"; import * as arrow17 from "apache-arrow-17";
import * as arrow18 from "apache-arrow-18"; import * as arrow18 from "apache-arrow-18";
import { MatchQuery, PhraseQuery, Table, connect } from "../lancedb"; import { Table, connect } from "../lancedb";
import { import {
Table as ArrowTable, Table as ArrowTable,
Field, Field,
@@ -33,7 +33,6 @@ import {
register, register,
} from "../lancedb/embedding"; } from "../lancedb/embedding";
import { Index } from "../lancedb/indices"; import { Index } from "../lancedb/indices";
import { instanceOfFullTextQuery } from "../lancedb/query";
describe.each([arrow15, arrow16, arrow17, arrow18])( describe.each([arrow15, arrow16, arrow17, arrow18])(
"Given a table", "Given a table",
@@ -507,15 +506,6 @@ describe("When creating an index", () => {
expect(indices2.length).toBe(0); expect(indices2.length).toBe(0);
}); });
it("should wait for index readiness", async () => {
// Create an index and then wait for it to be ready
await tbl.createIndex("vec");
const indices = await tbl.listIndices();
expect(indices.length).toBeGreaterThan(0);
const idxName = indices[0].name;
await expect(tbl.waitForIndex([idxName], 5)).resolves.toBeUndefined();
});
it("should search with distance range", async () => { it("should search with distance range", async () => {
await tbl.createIndex("vec"); await tbl.createIndex("vec");
@@ -643,23 +633,6 @@ describe("When creating an index", () => {
expect(plan2).not.toMatch("LanceScan"); expect(plan2).not.toMatch("LanceScan");
}); });
it("should be able to run analyze plan", async () => {
await tbl.createIndex("vec");
await tbl.add([
{
id: 300,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
tags: [],
},
]);
const plan = await tbl.query().nearestTo(queryVec).analyzePlan();
expect(plan).toMatch("AnalyzeExec");
expect(plan).toMatch("metrics=");
});
it("should be able to query with row id", async () => { it("should be able to query with row id", async () => {
const results = await tbl const results = await tbl
.query() .query()
@@ -833,7 +806,6 @@ describe("When creating an index", () => {
// Only build index over v1 // Only build index over v1
await tbl.createIndex("vec", { await tbl.createIndex("vec", {
config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }), config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
waitTimeoutSeconds: 30,
}); });
const rst = await tbl const rst = await tbl
@@ -878,44 +850,6 @@ describe("When creating an index", () => {
}); });
}); });
describe("When querying a table", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should throw an error when timeout is reached", async () => {
const db = await connect(tmpDir.name);
const data = makeArrowTable([
{ text: "a", vector: [0.1, 0.2] },
{ text: "b", vector: [0.3, 0.4] },
]);
const table = await db.createTable("test", data);
await table.createIndex("text", { config: Index.fts() });
await expect(
table.query().where("text != 'a'").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.query().nearestTo([0.0, 0.0]).toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.search("a", "fts").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table
.query()
.nearestToText("a")
.nearestTo([0.0, 0.0])
.toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
});
});
describe("Read consistency interval", () => { describe("Read consistency interval", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
beforeEach(() => { beforeEach(() => {
@@ -1313,56 +1247,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
const results = await table.search("hello").toArray(); const results = await table.search("hello").toArray();
expect(results[0].text).toBe(data[0].text); expect(results[0].text).toBe(data[0].text);
const query = new MatchQuery("goodbye", "text");
expect(instanceOfFullTextQuery(query)).toBe(true);
const results2 = await table
.search(new MatchQuery("goodbye", "text"))
.toArray();
expect(results2[0].text).toBe(data[1].text);
});
test("prewarm full text search index", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: ["lance database", "the", "search"], vector: [0.1, 0.2, 0.3] },
{ text: ["lance database"], vector: [0.4, 0.5, 0.6] },
{ text: ["lance", "search"], vector: [0.7, 0.8, 0.9] },
{ text: ["database", "search"], vector: [1.0, 1.1, 1.2] },
{ text: ["unrelated", "doc"], vector: [1.3, 1.4, 1.5] },
];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts(),
});
// For the moment, we just confirm we can call prewarmIndex without error
// and still search it afterwards
await table.prewarmIndex("text_idx");
const results = await table.search("lance").toArray();
expect(results.length).toBe(3);
});
test("full text index on list", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: ["lance database", "the", "search"], vector: [0.1, 0.2, 0.3] },
{ text: ["lance database"], vector: [0.4, 0.5, 0.6] },
{ text: ["lance", "search"], vector: [0.7, 0.8, 0.9] },
{ text: ["database", "search"], vector: [1.0, 1.1, 1.2] },
{ text: ["unrelated", "doc"], vector: [1.3, 1.4, 1.5] },
];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts(),
});
const results = await table.search("lance").toArray();
expect(results.length).toBe(3);
const results2 = await table.search('"lance database"').toArray();
expect(results2.length).toBe(2);
}); });
test("full text search without positions", async () => { test("full text search without positions", async () => {
@@ -1415,43 +1299,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(results.length).toBe(2); expect(results.length).toBe(2);
const phraseResults = await table.search('"hello world"').toArray(); const phraseResults = await table.search('"hello world"').toArray();
expect(phraseResults.length).toBe(1); expect(phraseResults.length).toBe(1);
const phraseResults2 = await table
.search(new PhraseQuery("hello world", "text"))
.toArray();
expect(phraseResults2.length).toBe(1);
});
test("full text search fuzzy query", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "fa", vector: [0.1, 0.2, 0.3] },
{ text: "fo", vector: [0.4, 0.5, 0.6] },
{ text: "fob", vector: [0.4, 0.5, 0.6] },
{ text: "focus", vector: [0.4, 0.5, 0.6] },
{ text: "foo", vector: [0.4, 0.5, 0.6] },
{ text: "food", vector: [0.4, 0.5, 0.6] },
{ text: "foul", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts(),
});
const results = await table
.search(new MatchQuery("foo", "text"))
.toArray();
expect(results.length).toBe(1);
expect(results[0].text).toBe("foo");
const fuzzyResults = await table
.search(new MatchQuery("foo", "text", { fuzziness: 1 }))
.toArray();
expect(fuzzyResults.length).toBe(4);
const resultSet = new Set(fuzzyResults.map((r) => r.text));
expect(resultSet.has("foo")).toBe(true);
expect(resultSet.has("fob")).toBe(true);
expect(resultSet.has("fo")).toBe(true);
expect(resultSet.has("food")).toBe(true);
}); });
test.each([ test.each([
@@ -1499,30 +1346,6 @@ describe("when calling explainPlan", () => {
}); });
}); });
describe("when calling analyzePlan", () => {
let tmpDir: tmp.DirResult;
let table: Table;
let queryVec: number[];
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const con = await connect(tmpDir.name);
table = await con.createTable("vectors", [{ id: 1, vector: [1.1, 0.9] }]);
});
afterEach(() => {
tmpDir.removeCallback();
});
it("retrieves runtime metrics", async () => {
queryVec = Array(2)
.fill(1)
.map(() => Math.random());
const plan = await table.query().nearestTo(queryVec).analyzePlan();
console.log("Query Plan:\n", plan); // <--- Print the plan
expect(plan).toMatch("AnalyzeExec");
});
});
describe("column name options", () => { describe("column name options", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
let table: Table; let table: Table;

View File

@@ -47,12 +47,6 @@ export {
QueryExecutionOptions, QueryExecutionOptions,
FullTextSearchOptions, FullTextSearchOptions,
RecordBatchIterator, RecordBatchIterator,
FullTextQuery,
MatchQuery,
PhraseQuery,
BoostQuery,
MultiMatchQuery,
FullTextQueryType,
} from "./query"; } from "./query";
export { export {

View File

@@ -681,6 +681,4 @@ export interface IndexOptions {
* The default is true * The default is true
*/ */
replace?: boolean; replace?: boolean;
waitTimeoutSeconds?: number;
} }

View File

@@ -11,14 +11,12 @@ import {
} from "./arrow"; } from "./arrow";
import { type IvfPqOptions } from "./indices"; import { type IvfPqOptions } from "./indices";
import { import {
JsFullTextQuery,
RecordBatchIterator as NativeBatchIterator, RecordBatchIterator as NativeBatchIterator,
Query as NativeQuery, Query as NativeQuery,
Table as NativeTable, Table as NativeTable,
VectorQuery as NativeVectorQuery, VectorQuery as NativeVectorQuery,
} from "./native"; } from "./native";
import { Reranker } from "./rerankers"; import { Reranker } from "./rerankers";
export class RecordBatchIterator implements AsyncIterator<RecordBatch> { export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
private promisedInner?: Promise<NativeBatchIterator>; private promisedInner?: Promise<NativeBatchIterator>;
private inner?: NativeBatchIterator; private inner?: NativeBatchIterator;
@@ -64,7 +62,7 @@ class RecordBatchIterable<
// biome-ignore lint/suspicious/noExplicitAny: skip // biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> { [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator( return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs), this.inner.execute(this.options?.maxBatchLength),
); );
} }
} }
@@ -80,11 +78,6 @@ export interface QueryExecutionOptions {
* in smaller chunks. * in smaller chunks.
*/ */
maxBatchLength?: number; maxBatchLength?: number;
/**
* Timeout for query execution in milliseconds
*/
timeoutMs?: number;
} }
/** /**
@@ -159,7 +152,7 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
} }
fullTextSearch( fullTextSearch(
query: string | FullTextQuery, query: string,
options?: Partial<FullTextSearchOptions>, options?: Partial<FullTextSearchOptions>,
): this { ): this {
let columns: string[] | null = null; let columns: string[] | null = null;
@@ -171,16 +164,9 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
} }
} }
this.doCall((inner: NativeQueryType) => { this.doCall((inner: NativeQueryType) =>
if (typeof query === "string") { inner.fullTextSearch(query, columns),
inner.fullTextSearch({ );
query: query,
columns: columns,
});
} else {
inner.fullTextSearch({ query: query.inner });
}
});
return this; return this;
} }
@@ -287,11 +273,9 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
options?: Partial<QueryExecutionOptions>, options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> { ): Promise<NativeBatchIterator> {
if (this.inner instanceof Promise) { if (this.inner instanceof Promise) {
return this.inner.then((inner) => return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
inner.execute(options?.maxBatchLength, options?.timeoutMs),
);
} else { } else {
return this.inner.execute(options?.maxBatchLength, options?.timeoutMs); return this.inner.execute(options?.maxBatchLength);
} }
} }
@@ -364,43 +348,6 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
return this.inner.explainPlan(verbose); return this.inner.explainPlan(verbose);
} }
} }
/**
* Executes the query and returns the physical query plan annotated with runtime metrics.
*
* This is useful for debugging and performance analysis, as it shows how the query was executed
* and includes metrics such as elapsed time, rows processed, and I/O statistics.
*
* @example
* import * as lancedb from "@lancedb/lancedb"
*
* const db = await lancedb.connect("./.lancedb");
* const table = await db.createTable("my_table", [
* { vector: [1.1, 0.9], id: "1" },
* ]);
*
* const plan = await table.query().nearestTo([0.5, 0.2]).analyzePlan();
*
* Example output (with runtime metrics inlined):
* AnalyzeExec verbose=true, metrics=[]
* ProjectionExec: expr=[id@3 as id, vector@0 as vector, _distance@2 as _distance], metrics=[output_rows=1, elapsed_compute=3.292µs]
* Take: columns="vector, _rowid, _distance, (id)", metrics=[output_rows=1, elapsed_compute=66.001µs, batches_processed=1, bytes_read=8, iops=1, requests=1]
* CoalesceBatchesExec: target_batch_size=1024, metrics=[output_rows=1, elapsed_compute=3.333µs]
* GlobalLimitExec: skip=0, fetch=10, metrics=[output_rows=1, elapsed_compute=167ns]
* FilterExec: _distance@2 IS NOT NULL, metrics=[output_rows=1, elapsed_compute=8.542µs]
* SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], metrics=[output_rows=1, elapsed_compute=63.25µs, row_replacements=1]
* KNNVectorDistance: metric=l2, metrics=[output_rows=1, elapsed_compute=114.333µs, output_batches=1]
* LanceScan: uri=/path/to/data, projection=[vector], row_id=true, row_addr=false, ordered=false, metrics=[output_rows=1, elapsed_compute=103.626µs, bytes_read=549, iops=2, requests=2]
*
* @returns A query execution plan with runtime metrics for each step.
*/
async analyzePlan(): Promise<string> {
if (this.inner instanceof Promise) {
return this.inner.then((inner) => inner.analyzePlan());
} else {
return this.inner.analyzePlan();
}
}
} }
/** /**
@@ -734,177 +681,8 @@ export class Query extends QueryBase<NativeQuery> {
} }
} }
nearestToText(query: string | FullTextQuery, columns?: string[]): Query { nearestToText(query: string, columns?: string[]): Query {
this.doCall((inner) => { this.doCall((inner) => inner.fullTextSearch(query, columns));
if (typeof query === "string") {
inner.fullTextSearch({
query: query,
columns: columns,
});
} else {
inner.fullTextSearch({ query: query.inner });
}
});
return this; return this;
} }
} }
/**
* Enum representing the types of full-text queries supported.
*
* - `Match`: Performs a full-text search for terms in the query string.
* - `MatchPhrase`: Searches for an exact phrase match in the text.
* - `Boost`: Boosts the relevance score of specific terms in the query.
* - `MultiMatch`: Searches across multiple fields for the query terms.
*/
export enum FullTextQueryType {
Match = "match",
MatchPhrase = "match_phrase",
Boost = "boost",
MultiMatch = "multi_match",
}
/**
* Represents a full-text query interface.
* This interface defines the structure and behavior for full-text queries,
* including methods to retrieve the query type and convert the query to a dictionary format.
*/
export interface FullTextQuery {
/**
* Returns the inner query object.
* This is the underlying query object used by the database engine.
* @ignore
*/
inner: JsFullTextQuery;
/**
* The type of the full-text query.
*/
queryType(): FullTextQueryType;
}
// biome-ignore lint/suspicious/noExplicitAny: we want any here
export function instanceOfFullTextQuery(obj: any): obj is FullTextQuery {
return obj != null && obj.inner instanceof JsFullTextQuery;
}
export class MatchQuery implements FullTextQuery {
/** @ignore */
public readonly inner: JsFullTextQuery;
/**
* Creates an instance of MatchQuery.
*
* @param query - The text query to search for.
* @param column - The name of the column to search within.
* @param options - Optional parameters for the match query.
* - `boost`: The boost factor for the query (default is 1.0).
* - `fuzziness`: The fuzziness level for the query (default is 0).
* - `maxExpansions`: The maximum number of terms to consider for fuzzy matching (default is 50).
*/
constructor(
query: string,
column: string,
options?: {
boost?: number;
fuzziness?: number;
maxExpansions?: number;
},
) {
let fuzziness = options?.fuzziness;
if (fuzziness === undefined) {
fuzziness = 0;
}
this.inner = JsFullTextQuery.matchQuery(
query,
column,
options?.boost ?? 1.0,
fuzziness,
options?.maxExpansions ?? 50,
);
}
queryType(): FullTextQueryType {
return FullTextQueryType.Match;
}
}
export class PhraseQuery implements FullTextQuery {
/** @ignore */
public readonly inner: JsFullTextQuery;
/**
* Creates an instance of `PhraseQuery`.
*
* @param query - The phrase to search for in the specified column.
* @param column - The name of the column to search within.
*/
constructor(query: string, column: string) {
this.inner = JsFullTextQuery.phraseQuery(query, column);
}
queryType(): FullTextQueryType {
return FullTextQueryType.MatchPhrase;
}
}
export class BoostQuery implements FullTextQuery {
/** @ignore */
public readonly inner: JsFullTextQuery;
/**
* Creates an instance of BoostQuery.
* The boost returns documents that match the positive query,
* but penalizes those that match the negative query.
* the penalty is controlled by the `negativeBoost` parameter.
*
* @param positive - The positive query that boosts the relevance score.
* @param negative - The negative query that reduces the relevance score.
* @param options - Optional parameters for the boost query.
* - `negativeBoost`: The boost factor for the negative query (default is 0.0).
*/
constructor(
positive: FullTextQuery,
negative: FullTextQuery,
options?: {
negativeBoost?: number;
},
) {
this.inner = JsFullTextQuery.boostQuery(
positive.inner,
negative.inner,
options?.negativeBoost,
);
}
queryType(): FullTextQueryType {
return FullTextQueryType.Boost;
}
}
export class MultiMatchQuery implements FullTextQuery {
/** @ignore */
public readonly inner: JsFullTextQuery;
/**
* Creates an instance of MultiMatchQuery.
*
* @param query - The text query to search for across multiple columns.
* @param columns - An array of column names to search within.
* @param options - Optional parameters for the multi-match query.
* - `boosts`: An array of boost factors for each column (default is 1.0 for all).
*/
constructor(
query: string,
columns: string[],
options?: {
boosts?: number[];
},
) {
this.inner = JsFullTextQuery.multiMatchQuery(
query,
columns,
options?.boosts,
);
}
queryType(): FullTextQueryType {
return FullTextQueryType.MultiMatch;
}
}

View File

@@ -22,12 +22,7 @@ import {
OptimizeStats, OptimizeStats,
Table as _NativeTable, Table as _NativeTable,
} from "./native"; } from "./native";
import { import { Query, VectorQuery } from "./query";
FullTextQuery,
Query,
VectorQuery,
instanceOfFullTextQuery,
} from "./query";
import { sanitizeType } from "./sanitize"; import { sanitizeType } from "./sanitize";
import { IntoSql, toSQL } from "./util"; import { IntoSql, toSQL } from "./util";
export { IndexConfig } from "./native"; export { IndexConfig } from "./native";
@@ -235,30 +230,6 @@ export abstract class Table {
*/ */
abstract dropIndex(name: string): Promise<void>; abstract dropIndex(name: string): Promise<void>;
/**
* Prewarm an index in the table.
*
* @param name The name of the index.
*
* This will load the index into memory. This may reduce the cold-start time for
* future queries. If the index does not fit in the cache then this call may be
* wasteful.
*/
abstract prewarmIndex(name: string): Promise<void>;
/**
* Waits for asynchronous indexing to complete on the table.
*
* @param indexNames The name of the indices to wait for
* @param timeoutSeconds The number of seconds to wait before timing out
*
* This will raise an error if the indices are not created and fully indexed within the timeout.
*/
abstract waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void>;
/** /**
* Create a {@link Query} Builder. * Create a {@link Query} Builder.
* *
@@ -323,7 +294,7 @@ export abstract class Table {
* if the query is a string and no embedding function is defined, it will be treated as a full text search query * if the query is a string and no embedding function is defined, it will be treated as a full text search query
*/ */
abstract search( abstract search(
query: string | IntoVector | FullTextQuery, query: string | IntoVector,
queryType?: string, queryType?: string,
ftsColumns?: string | string[], ftsColumns?: string | string[],
): VectorQuery | Query; ): VectorQuery | Query;
@@ -582,39 +553,23 @@ export class LocalTable extends Table {
// Bit of a hack to get around the fact that TS has no package-scope. // Bit of a hack to get around the fact that TS has no package-scope.
// biome-ignore lint/suspicious/noExplicitAny: skip // biome-ignore lint/suspicious/noExplicitAny: skip
const nativeIndex = (options?.config as any)?.inner; const nativeIndex = (options?.config as any)?.inner;
await this.inner.createIndex( await this.inner.createIndex(nativeIndex, column, options?.replace);
nativeIndex,
column,
options?.replace,
options?.waitTimeoutSeconds,
);
} }
async dropIndex(name: string): Promise<void> { async dropIndex(name: string): Promise<void> {
await this.inner.dropIndex(name); await this.inner.dropIndex(name);
} }
async prewarmIndex(name: string): Promise<void> {
await this.inner.prewarmIndex(name);
}
async waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void> {
await this.inner.waitForIndex(indexNames, timeoutSeconds);
}
query(): Query { query(): Query {
return new Query(this.inner); return new Query(this.inner);
} }
search( search(
query: string | IntoVector | FullTextQuery, query: string | IntoVector,
queryType: string = "auto", queryType: string = "auto",
ftsColumns?: string | string[], ftsColumns?: string | string[],
): VectorQuery | Query { ): VectorQuery | Query {
if (typeof query !== "string" && !instanceOfFullTextQuery(query)) { if (typeof query !== "string") {
if (queryType === "fts") { if (queryType === "fts") {
throw new Error("Cannot perform full text search on a vector query"); throw new Error("Cannot perform full text search on a vector query");
} }
@@ -630,10 +585,7 @@ export class LocalTable extends Table {
// The query type is auto or vector // The query type is auto or vector
// fall back to full text search if no embedding functions are defined and the query is a string // fall back to full text search if no embedding functions are defined and the query is a string
if ( if (queryType === "auto" && getRegistry().length() === 0) {
queryType === "auto" &&
(getRegistry().length() === 0 || instanceOfFullTextQuery(query))
) {
return this.query().fullTextSearch(query, { return this.query().fullTextSearch(query, {
columns: ftsColumns, columns: ftsColumns,
}); });

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-x64", "name": "@lancedb/lancedb-darwin-x64",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.darwin-x64.node", "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-musl", "name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node", "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-musl", "name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node", "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-arm64-msvc", "name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": [ "os": [
"win32" "win32"
], ],

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

250
nodejs/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.19.0-beta.11", "version": "0.18.2-beta.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -2304,20 +2304,89 @@
} }
}, },
"node_modules/@babel/code-frame": { "node_modules/@babel/code-frame": {
"version": "7.26.2", "version": "7.23.5",
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.23.5.tgz",
"integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", "integrity": "sha512-CgH3s1a96LipHCmSUmYFPwY7MNx8C3avkq7i4Wl3cfa662ldtUe4VM1TPXX70pfmrlWTb6jLqTYrZyT2ZTJBgA==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"@babel/helper-validator-identifier": "^7.25.9", "@babel/highlight": "^7.23.4",
"js-tokens": "^4.0.0", "chalk": "^2.4.2"
"picocolors": "^1.0.0"
}, },
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
} }
}, },
"node_modules/@babel/code-frame/node_modules/ansi-styles": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz",
"integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==",
"dev": true,
"dependencies": {
"color-convert": "^1.9.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/code-frame/node_modules/chalk": {
"version": "2.4.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz",
"integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==",
"dev": true,
"dependencies": {
"ansi-styles": "^3.2.1",
"escape-string-regexp": "^1.0.5",
"supports-color": "^5.3.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/code-frame/node_modules/color-convert": {
"version": "1.9.3",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz",
"integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==",
"dev": true,
"dependencies": {
"color-name": "1.1.3"
}
},
"node_modules/@babel/code-frame/node_modules/color-name": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz",
"integrity": "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==",
"dev": true
},
"node_modules/@babel/code-frame/node_modules/escape-string-regexp": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz",
"integrity": "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==",
"dev": true,
"engines": {
"node": ">=0.8.0"
}
},
"node_modules/@babel/code-frame/node_modules/has-flag": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz",
"integrity": "sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==",
"dev": true,
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/code-frame/node_modules/supports-color": {
"version": "5.5.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz",
"integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==",
"dev": true,
"dependencies": {
"has-flag": "^3.0.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/compat-data": { "node_modules/@babel/compat-data": {
"version": "7.23.5", "version": "7.23.5",
"resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.23.5.tgz", "resolved": "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.23.5.tgz",
@@ -2520,21 +2589,19 @@
} }
}, },
"node_modules/@babel/helper-string-parser": { "node_modules/@babel/helper-string-parser": {
"version": "7.25.9", "version": "7.23.4",
"resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.23.4.tgz",
"integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", "integrity": "sha512-803gmbQdqwdf4olxrX4AJyFBV/RTr3rSmOj0rKwesmzlfhYNDEs+/iOcznzpNWlJlIlTJC2QfPFcHB6DlzdVLQ==",
"dev": true, "dev": true,
"license": "MIT",
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
} }
}, },
"node_modules/@babel/helper-validator-identifier": { "node_modules/@babel/helper-validator-identifier": {
"version": "7.25.9", "version": "7.22.20",
"resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz",
"integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", "integrity": "sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==",
"dev": true, "dev": true,
"license": "MIT",
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
} }
@@ -2549,28 +2616,109 @@
} }
}, },
"node_modules/@babel/helpers": { "node_modules/@babel/helpers": {
"version": "7.27.0", "version": "7.23.8",
"resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.27.0.tgz", "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.23.8.tgz",
"integrity": "sha512-U5eyP/CTFPuNE3qk+WZMxFkp/4zUzdceQlfzf7DdGdhp+Fezd7HD+i8Y24ZuTMKX3wQBld449jijbGq6OdGNQg==", "integrity": "sha512-KDqYz4PiOWvDFrdHLPhKtCThtIcKVy6avWD2oG4GEvyQ+XDZwHD4YQd+H2vNMnq2rkdxsDkU82T+Vk8U/WXHRQ==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"@babel/template": "^7.27.0", "@babel/template": "^7.22.15",
"@babel/types": "^7.27.0" "@babel/traverse": "^7.23.7",
"@babel/types": "^7.23.6"
}, },
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
} }
}, },
"node_modules/@babel/parser": { "node_modules/@babel/highlight": {
"version": "7.27.0", "version": "7.23.4",
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.27.0.tgz", "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.23.4.tgz",
"integrity": "sha512-iaepho73/2Pz7w2eMS0Q5f83+0RKI7i4xmiYeBmDzfRVbQtTOG7Ts0S4HzJVsTMGI9keU8rNfuZr8DKfSt7Yyg==", "integrity": "sha512-acGdbYSfp2WheJoJm/EBBBLh/ID8KDc64ISZ9DYtBmC8/Q204PZJLHyzeB5qMzJ5trcOkybd78M4x2KWsUq++A==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"@babel/types": "^7.27.0" "@babel/helper-validator-identifier": "^7.22.20",
"chalk": "^2.4.2",
"js-tokens": "^4.0.0"
}, },
"engines": {
"node": ">=6.9.0"
}
},
"node_modules/@babel/highlight/node_modules/ansi-styles": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz",
"integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==",
"dev": true,
"dependencies": {
"color-convert": "^1.9.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/highlight/node_modules/chalk": {
"version": "2.4.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz",
"integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==",
"dev": true,
"dependencies": {
"ansi-styles": "^3.2.1",
"escape-string-regexp": "^1.0.5",
"supports-color": "^5.3.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/highlight/node_modules/color-convert": {
"version": "1.9.3",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz",
"integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==",
"dev": true,
"dependencies": {
"color-name": "1.1.3"
}
},
"node_modules/@babel/highlight/node_modules/color-name": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz",
"integrity": "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==",
"dev": true
},
"node_modules/@babel/highlight/node_modules/escape-string-regexp": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz",
"integrity": "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==",
"dev": true,
"engines": {
"node": ">=0.8.0"
}
},
"node_modules/@babel/highlight/node_modules/has-flag": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz",
"integrity": "sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==",
"dev": true,
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/highlight/node_modules/supports-color": {
"version": "5.5.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz",
"integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==",
"dev": true,
"dependencies": {
"has-flag": "^3.0.0"
},
"engines": {
"node": ">=4"
}
},
"node_modules/@babel/parser": {
"version": "7.23.6",
"resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.23.6.tgz",
"integrity": "sha512-Z2uID7YJ7oNvAI20O9X0bblw7Qqs8Q2hFy0R9tAfnfLkp5MW0UH9eUvnDSnFwKZ0AvgS1ucqR4KzvVHgnke1VQ==",
"dev": true,
"bin": { "bin": {
"parser": "bin/babel-parser.js" "parser": "bin/babel-parser.js"
}, },
@@ -2756,15 +2904,14 @@
} }
}, },
"node_modules/@babel/template": { "node_modules/@babel/template": {
"version": "7.27.0", "version": "7.22.15",
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.0.tgz", "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.22.15.tgz",
"integrity": "sha512-2ncevenBqXI6qRMukPlXwHKHchC7RyMuu4xv5JBXRfOGVcTy1mXCD12qrp7Jsoxll1EV3+9sE4GugBVRjT2jFA==", "integrity": "sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"@babel/code-frame": "^7.26.2", "@babel/code-frame": "^7.22.13",
"@babel/parser": "^7.27.0", "@babel/parser": "^7.22.15",
"@babel/types": "^7.27.0" "@babel/types": "^7.22.15"
}, },
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
@@ -2801,14 +2948,14 @@
} }
}, },
"node_modules/@babel/types": { "node_modules/@babel/types": {
"version": "7.27.0", "version": "7.23.6",
"resolved": "https://registry.npmjs.org/@babel/types/-/types-7.27.0.tgz", "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.23.6.tgz",
"integrity": "sha512-H45s8fVLYjbhFH62dIJ3WtmJ6RSPt/3DRO0ZcT2SUiYiQyz3BLVb9ADEnLl91m74aQPS3AzzeajZHYOalWe3bg==", "integrity": "sha512-+uarb83brBzPKN38NX1MkB6vb6+mwvR6amUulqAE7ccQw1pEl+bCia9TbdG1lsnFP7lZySvUn37CHyXQdfTwzg==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"@babel/helper-string-parser": "^7.25.9", "@babel/helper-string-parser": "^7.23.4",
"@babel/helper-validator-identifier": "^7.25.9" "@babel/helper-validator-identifier": "^7.22.20",
"to-fast-properties": "^2.0.0"
}, },
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
@@ -5403,11 +5550,10 @@
"devOptional": true "devOptional": true
}, },
"node_modules/axios": { "node_modules/axios": {
"version": "1.8.4", "version": "1.7.7",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.8.4.tgz", "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.7.tgz",
"integrity": "sha512-eBSYY4Y68NNlHbHBMdeDmKNtDgXWhQsJcGqzO3iLUM0GraQFSS9cVgPX5I9b3lbdFKyYoAEGAZF1DwhTaljNAw==", "integrity": "sha512-S4kL7XrjgBmvdGut0sN3yJxqYzrDOnivkBiN0OFs6hLiUam3UPvswUo0kqGyhqUZGEOytHyumEdXsAkgCOUf3Q==",
"dev": true, "dev": true,
"license": "MIT",
"dependencies": { "dependencies": {
"follow-redirects": "^1.15.6", "follow-redirects": "^1.15.6",
"form-data": "^4.0.0", "form-data": "^4.0.0",
@@ -7723,8 +7869,7 @@
"version": "4.0.0", "version": "4.0.0",
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
"integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==",
"dev": true, "dev": true
"license": "MIT"
}, },
"node_modules/js-yaml": { "node_modules/js-yaml": {
"version": "3.14.1", "version": "3.14.1",
@@ -9215,6 +9360,15 @@
"integrity": "sha512-3f0uOEAQwIqGuWW2MVzYg8fV/QNnc/IpuJNG837rLuczAaLVHslWHZQj4IGiEl5Hs3kkbhwL9Ab7Hrsmuj+Smw==", "integrity": "sha512-3f0uOEAQwIqGuWW2MVzYg8fV/QNnc/IpuJNG837rLuczAaLVHslWHZQj4IGiEl5Hs3kkbhwL9Ab7Hrsmuj+Smw==",
"dev": true "dev": true
}, },
"node_modules/to-fast-properties": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz",
"integrity": "sha512-/OaKK0xYrs3DmxRYqL/yDc+FxFUVYhDlXMhRmv3z915w2HF1tnN1omB354j8VUGO/hbRzyD6Y3sA7v7GS/ceog==",
"dev": true,
"engines": {
"node": ">=4"
}
},
"node_modules/to-regex-range": { "node_modules/to-regex-range": {
"version": "5.0.1", "version": "5.0.1",
"resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz",

View File

@@ -11,7 +11,7 @@
"ann" "ann"
], ],
"private": false, "private": false,
"version": "0.19.0-beta.11", "version": "0.18.2-beta.1",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",
@@ -29,7 +29,6 @@
"aarch64-apple-darwin", "aarch64-apple-darwin",
"x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu", "aarch64-unknown-linux-gnu",
"x86_64-unknown-linux-musl",
"aarch64-unknown-linux-musl", "aarch64-unknown-linux-musl",
"x86_64-pc-windows-msvc", "x86_64-pc-windows-msvc",
"aarch64-pc-windows-msvc" "aarch64-pc-windows-msvc"

View File

@@ -3,9 +3,7 @@
use std::sync::Arc; use std::sync::Arc;
use lancedb::index::scalar::{ use lancedb::index::scalar::FullTextSearchQuery;
BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, PhraseQuery,
};
use lancedb::query::ExecutableQuery; use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery; use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase; use lancedb::query::QueryBase;
@@ -40,10 +38,9 @@ impl Query {
} }
#[napi] #[napi]
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> { pub fn full_text_search(&mut self, query: String, columns: Option<Vec<String>>) {
let query = parse_fts_query(query)?; let query = FullTextSearchQuery::new(query).columns(columns);
self.inner = self.inner.clone().full_text_search(query); self.inner = self.inner.clone().full_text_search(query);
Ok(())
} }
#[napi] #[napi]
@@ -90,15 +87,11 @@ impl Query {
pub async fn execute( pub async fn execute(
&self, &self,
max_batch_length: Option<u32>, max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> { ) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default(); let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length { if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length; execution_opts.max_batch_length = max_batch_length;
} }
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self let inner_stream = self
.inner .inner
.execute_with_options(execution_opts) .execute_with_options(execution_opts)
@@ -121,16 +114,6 @@ impl Query {
)) ))
}) })
} }
#[napi(catch_unwind)]
pub async fn analyze_plan(&self) -> napi::Result<String> {
self.inner.analyze_plan().await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to execute analyze plan: {}",
convert_error(&e)
))
})
}
} }
#[napi] #[napi]
@@ -202,10 +185,9 @@ impl VectorQuery {
} }
#[napi] #[napi]
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> { pub fn full_text_search(&mut self, query: String, columns: Option<Vec<String>>) {
let query = parse_fts_query(query)?; let query = FullTextSearchQuery::new(query).columns(columns);
self.inner = self.inner.clone().full_text_search(query); self.inner = self.inner.clone().full_text_search(query);
Ok(())
} }
#[napi] #[napi]
@@ -250,15 +232,11 @@ impl VectorQuery {
pub async fn execute( pub async fn execute(
&self, &self,
max_batch_length: Option<u32>, max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> { ) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default(); let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length { if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length; execution_opts.max_batch_length = max_batch_length;
} }
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self let inner_stream = self
.inner .inner
.execute_with_options(execution_opts) .execute_with_options(execution_opts)
@@ -281,127 +259,4 @@ impl VectorQuery {
)) ))
}) })
} }
#[napi(catch_unwind)]
pub async fn analyze_plan(&self) -> napi::Result<String> {
self.inner.analyze_plan().await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to execute analyze plan: {}",
convert_error(&e)
))
})
}
}
#[napi]
#[derive(Debug, Clone)]
pub struct JsFullTextQuery {
pub(crate) inner: FtsQuery,
}
#[napi]
impl JsFullTextQuery {
#[napi(factory)]
pub fn match_query(
query: String,
column: String,
boost: f64,
fuzziness: Option<u32>,
max_expansions: u32,
) -> napi::Result<Self> {
Ok(Self {
inner: MatchQuery::new(query)
.with_column(Some(column))
.with_boost(boost as f32)
.with_fuzziness(fuzziness)
.with_max_expansions(max_expansions as usize)
.into(),
})
}
#[napi(factory)]
pub fn phrase_query(query: String, column: String) -> napi::Result<Self> {
Ok(Self {
inner: PhraseQuery::new(query).with_column(Some(column)).into(),
})
}
#[napi(factory)]
#[allow(clippy::use_self)] // NAPI doesn't allow Self here but clippy reports it
pub fn boost_query(
positive: &JsFullTextQuery,
negative: &JsFullTextQuery,
negative_boost: Option<f64>,
) -> napi::Result<Self> {
Ok(Self {
inner: BoostQuery::new(
positive.inner.clone(),
negative.inner.clone(),
negative_boost.map(|v| v as f32),
)
.into(),
})
}
#[napi(factory)]
pub fn multi_match_query(
query: String,
columns: Vec<String>,
boosts: Option<Vec<f64>>,
) -> napi::Result<Self> {
let q = match boosts {
Some(boosts) => MultiMatchQuery::try_new(query, columns)
.and_then(|q| q.try_with_boosts(boosts.into_iter().map(|v| v as f32).collect())),
None => MultiMatchQuery::try_new(query, columns),
}
.map_err(|e| {
napi::Error::from_reason(format!("Failed to create multi match query: {}", e))
})?;
Ok(Self { inner: q.into() })
}
}
fn parse_fts_query(query: napi::JsObject) -> napi::Result<FullTextSearchQuery> {
if let Ok(Some(query)) = query.get::<_, &JsFullTextQuery>("query") {
Ok(FullTextSearchQuery::new_query(query.inner.clone()))
} else if let Ok(Some(query_text)) = query.get::<_, String>("query") {
let mut query_text = query_text;
let columns = query.get::<_, Option<Vec<String>>>("columns")?.flatten();
let is_phrase =
query_text.len() >= 2 && query_text.starts_with('"') && query_text.ends_with('"');
let is_multi_match = columns.as_ref().map(|cols| cols.len() > 1).unwrap_or(false);
if is_phrase {
// Remove the surrounding quotes for phrase queries
query_text = query_text[1..query_text.len() - 1].to_string();
}
let query: FtsQuery = match (is_phrase, is_multi_match) {
(false, _) => MatchQuery::new(query_text).into(),
(true, false) => PhraseQuery::new(query_text).into(),
(true, true) => {
return Err(napi::Error::from_reason(
"Phrase queries cannot be used with multiple columns.",
));
}
};
let mut query = FullTextSearchQuery::new_query(query);
if let Some(cols) = columns {
if !cols.is_empty() {
query = query.with_columns(&cols).map_err(|e| {
napi::Error::from_reason(format!(
"Failed to set full text search columns: {}",
e
))
})?;
}
}
Ok(query)
} else {
Err(napi::Error::from_reason(
"Invalid full text search query object".to_string(),
))
}
} }

View File

@@ -111,7 +111,6 @@ impl Table {
index: Option<&Index>, index: Option<&Index>,
column: String, column: String,
replace: Option<bool>, replace: Option<bool>,
wait_timeout_s: Option<i64>,
) -> napi::Result<()> { ) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index { let lancedb_index = if let Some(index) = index {
index.consume()? index.consume()?
@@ -122,10 +121,6 @@ impl Table {
if let Some(replace) = replace { if let Some(replace) = replace {
builder = builder.replace(replace); builder = builder.replace(replace);
} }
if let Some(timeout) = wait_timeout_s {
builder =
builder.wait_timeout(std::time::Duration::from_secs(timeout.try_into().unwrap()));
}
builder.execute().await.default_error() builder.execute().await.default_error()
} }
@@ -137,26 +132,6 @@ impl Table {
.default_error() .default_error()
} }
#[napi(catch_unwind)]
pub async fn prewarm_index(&self, index_name: String) -> napi::Result<()> {
self.inner_ref()?
.prewarm_index(&index_name)
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn wait_for_index(&self, index_names: Vec<String>, timeout_s: i64) -> Result<()> {
let timeout = std::time::Duration::from_secs(timeout_s.try_into().unwrap());
let index_names: Vec<&str> = index_names.iter().map(|s| s.as_str()).collect();
let slice: &[&str] = &index_names;
self.inner_ref()?
.wait_for_index(slice, timeout)
.await
.default_error()
}
#[napi(catch_unwind)] #[napi(catch_unwind)]
pub async fn update( pub async fn update(
&self, &self,

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.22.0" current_version = "0.21.2"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.22.0" version = "0.21.2"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true

View File

@@ -4,12 +4,11 @@ name = "lancedb"
dynamic = ["version"] dynamic = ["version"]
dependencies = [ dependencies = [
"deprecation", "deprecation",
"numpy", "tqdm>=4.27.0",
"overrides>=0.7",
"packaging",
"pyarrow>=14", "pyarrow>=14",
"pydantic>=1.10", "pydantic>=1.10",
"tqdm>=4.27.0", "packaging",
"overrides>=0.7",
] ]
description = "lancedb" description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -43,9 +42,6 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies] [project.optional-dependencies]
pylance = [
"pylance>=0.25",
]
tests = [ tests = [
"aiohttp", "aiohttp",
"boto3", "boto3",
@@ -58,8 +54,7 @@ tests = [
"polars>=0.19, <=1.3.0", "polars>=0.19, <=1.3.0",
"tantivy", "tantivy",
"pyarrow-stubs", "pyarrow-stubs",
"pylance>=0.25", "pylance>=0.23.2",
"requests",
] ]
dev = [ dev = [
"ruff", "ruff",
@@ -77,7 +72,6 @@ embeddings = [
"pillow", "pillow",
"open-clip-torch", "open-clip-torch",
"cohere", "cohere",
"colpali-engine>=0.3.10",
"huggingface_hub", "huggingface_hub",
"InstructorEmbedding", "InstructorEmbedding",
"google.generativeai", "google.generativeai",

View File

@@ -1,4 +1,3 @@
from datetime import timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, Literal from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa import pyarrow as pa
@@ -49,11 +48,10 @@ class Table:
async def version(self) -> int: ... async def version(self) -> int: ...
async def checkout(self, version: int): ... async def checkout(self, version: int): ...
async def checkout_latest(self): ... async def checkout_latest(self): ...
async def restore(self, version: Optional[int] = None): ... async def restore(self): ...
async def list_indices(self) -> list[IndexConfig]: ... async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str): ... async def delete(self, filter: str): ...
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ... async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
async def add_columns_with_schema(self, schema: pa.Schema) -> None: ...
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ... async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
async def optimize( async def optimize(
self, self,
@@ -95,11 +93,7 @@ class Query:
def postfilter(self): ... def postfilter(self): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def execute( async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
async def explain_plan(self, verbose: Optional[bool]) -> str: ...
async def analyze_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ... def to_query_request(self) -> PyQueryRequest: ...
class FTSQuery: class FTSQuery:
@@ -113,9 +107,8 @@ class FTSQuery:
def get_query(self) -> str: ... def get_query(self) -> str: ...
def add_query_vector(self, query_vec: pa.Array) -> None: ... def add_query_vector(self, query_vec: pa.Array) -> None: ...
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def execute( async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
self, max_batch_length: Optional[int], timeout: Optional[timedelta] async def explain_plan(self) -> str: ...
) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ... def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery: class VectorQuery:

View File

@@ -9,7 +9,7 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
import pyarrow.dataset import pyarrow.dataset
from .dependencies import _check_for_pandas, pandas as pd from .dependencies import pandas as pd
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
@@ -63,7 +63,7 @@ def data_to_reader(
data: DATA, schema: Optional[pa.Schema] = None data: DATA, schema: Optional[pa.Schema] = None
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
"""Convert various types of input into a RecordBatchReader""" """Convert various types of input into a RecordBatchReader"""
if _check_for_pandas(data) and isinstance(data, pd.DataFrame): if pd is not None and isinstance(data, pd.DataFrame):
return pa.Table.from_pandas(data, schema=schema).to_reader() return pa.Table.from_pandas(data, schema=schema).to_reader()
elif isinstance(data, pa.Table): elif isinstance(data, pa.Table):
return data.to_reader() return data.to_reader()

View File

@@ -19,4 +19,3 @@ from .imagebind import ImageBindEmbeddings
from .jinaai import JinaEmbeddings from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction from .voyageai import VoyageAIEmbeddingFunction
from .colpali import ColPaliEmbeddings

View File

@@ -1,255 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import lru_cache
from typing import List, Union, Optional, Any
import numpy as np
import io
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import TEXT, IMAGES, is_flash_attn_2_available
@register("colpali")
class ColPaliEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
device : str
The device for inference (default "cuda:0").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model. (default None, bitsandbytes needed)
batch_size : int
Batch size for processing inputs (default 2).
"""
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pool_factor: int = 2
quantization_config: Optional[Any] = None
batch_size: int = 2
_model = None
_processor = None
_token_pooler = None
_vector_dim = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
self.quantization_config,
)
@staticmethod
@lru_cache(maxsize=1)
def _load_model(
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
raise ValueError("quantization_config must be a BitsAndBytesConfig")
if dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif dtype == "float16":
torch_dtype = torch.float16
elif dtype == "float64":
torch_dtype = torch.float64
else:
torch_dtype = torch.float32
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
return model, processor, token_pooler
def ndims(self):
"""
Return the dimension of a vector in the multivector output (e.g., 128).
"""
torch = attempt_import_or_raise("torch", "torch")
if self._vector_dim is None:
dummy_query = "test"
batch_queries = self._processor.process_queries([dummy_query]).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
self._vector_dim = query_embeddings[0].shape[-1]
return self._vector_dim
def _process_embeddings(self, embeddings):
"""
Format model embeddings into List[List[float]].
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if isinstance(embeddings, torch.Tensor):
tensors = embeddings.detach().cpu()
if tensors.dtype == torch.bfloat16:
tensors = tensors.to(torch.float32)
return (
tensors.numpy()
.astype(np.float64 if self.dtype == "float64" else np.float32)
.tolist()
)
return []
def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
"""
Generate embeddings for text input.
"""
torch = attempt_import_or_raise("torch", "torch")
text = self.sanitize_input(text)
all_embeddings = []
for i in range(0, len(text), self.batch_size):
batch_text = text[i : i + self.batch_size]
batch_queries = self._processor.process_queries(batch_text).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
def _prepare_images(self, images: IMAGES) -> List:
"""
Convert image inputs to PIL Images.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
try:
for image in images:
if isinstance(image, str):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)
except Exception as e:
raise ValueError(f"Failed to process image: {e}")
return pil_images
def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
"""
Generate embeddings for a batch of images.
"""
torch = attempt_import_or_raise("torch", "torch")
pil_images = self._prepare_images(images)
all_embeddings = []
for i in range(0, len(pil_images), self.batch_size):
batch_images = pil_images[i : i + self.batch_size]
batch_images = self._processor.process_images(batch_images).to(
self._model.device
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
def compute_query_embeddings(
self, query: Union[str, IMAGES], *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a single user query (text only).
"""
if not isinstance(query, str):
raise ValueError(
"Query must be a string, image to image search is not supported"
)
return self.generate_text_embeddings([query])
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a batch of source images.
Parameters
----------
images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
Batch of images (paths, URLs, bytes, or PIL Images).
"""
images = self.sanitize_input(images)
return self.generate_image_embeddings(images)

View File

@@ -18,7 +18,6 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
from ..dependencies import pandas as pd from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise
# ruff: noqa: PERF203 # ruff: noqa: PERF203
@@ -276,12 +275,3 @@ def url_retrieve(url: str):
def api_key_not_found_help(provider): def api_key_not_found_help(provider):
logging.error("Could not find API key for %s", provider) logging.error("Could not find API key for %s", provider)
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.") raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
def is_flash_attn_2_available():
try:
attempt_import_or_raise("flash_attn", "flash_attn")
return True
except ImportError:
return False

View File

@@ -1,12 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
from typing import ClassVar, TYPE_CHECKING, List, Union, Any
from pathlib import Path
from urllib.parse import urlparse import os
from io import BytesIO from typing import ClassVar, TYPE_CHECKING, List, Union
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
@@ -14,100 +11,12 @@ import pyarrow as pa
from ..util import attempt_import_or_raise from ..util import attempt_import_or_raise
from .base import EmbeddingFunction from .base import EmbeddingFunction
from .registry import register from .registry import register
from .utils import api_key_not_found_help, IMAGES, TEXT from .utils import api_key_not_found_help, IMAGES
if TYPE_CHECKING: if TYPE_CHECKING:
import PIL import PIL
def is_valid_url(text):
try:
parsed = urlparse(text)
return bool(parsed.scheme) and bool(parsed.netloc)
except Exception:
return False
def transform_input(input_data: Union[str, bytes, Path]):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL.Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL.Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, Path):
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
else:
raise ValueError("Each input should be either str, bytes, Path or Image.")
return {"content": [content]}
def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
raise ValueError(
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
def sanitize_text_input(inputs: TEXT) -> List[str]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(inputs, str):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
raise ValueError(f"Input type {type(inputs)} not allowed with text model.")
if not all(isinstance(x, str) for x in inputs):
raise ValueError("Each input should be str.")
return inputs
@register("voyageai") @register("voyageai")
class VoyageAIEmbeddingFunction(EmbeddingFunction): class VoyageAIEmbeddingFunction(EmbeddingFunction):
""" """
@@ -165,11 +74,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
] ]
multimodal_embedding_models: list = ["voyage-multimodal-3"] multimodal_embedding_models: list = ["voyage-multimodal-3"]
def _is_multimodal_model(self, model_name: str):
return (
model_name in self.multimodal_embedding_models or "multimodal" in model_name
)
def ndims(self): def ndims(self):
if self.name == "voyage-3-lite": if self.name == "voyage-3-lite":
return 512 return 512
@@ -181,12 +85,55 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-finance-2", "voyage-finance-2",
"voyage-multilingual-2", "voyage-multilingual-2",
"voyage-law-2", "voyage-law-2",
"voyage-multimodal-3",
]: ]:
return 1024 return 1024
else: else:
raise ValueError(f"Model {self.name} not supported") raise ValueError(f"Model {self.name} not supported")
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def generate_text_embeddings(self, text: str, **kwargs) -> np.ndarray:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
input_type: Optional[str]
truncation: Optional[bool]
"""
client = VoyageAIEmbeddingFunction._get_client()
if self.name in self.text_embedding_models:
rs = client.embed(texts=[text], model=self.name, **kwargs)
elif self.name in self.multimodal_embedding_models:
rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs)
else:
raise ValueError(
f"Model {self.name} not supported to generate text embeddings"
)
return rs.embeddings[0]
def generate_image_embedding(
self, image: "PIL.Image.Image", **kwargs
) -> np.ndarray:
rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed(
inputs=[[image]], model=self.name, **kwargs
)
return rs.embeddings[0]
def compute_query_embeddings( def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]: ) -> List[np.ndarray]:
@@ -197,52 +144,23 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
---------- ----------
query : Union[str, PIL.Image.Image] query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image. The query to embed. A query can be either text or an image.
Returns
-------
List[np.array]: the list of embeddings
""" """
client = VoyageAIEmbeddingFunction._get_client() if isinstance(query, str):
if self._is_multimodal_model(self.name): return [self.generate_text_embeddings(query, input_type="query")]
result = client.multimodal_embed(
inputs=[[query]], model=self.name, input_type="query", **kwargs
)
else: else:
result = client.embed( PIL = attempt_import_or_raise("PIL", "pillow")
texts=[query], model=self.name, input_type="query", **kwargs if isinstance(query, PIL.Image.Image):
) return [self.generate_image_embedding(query, input_type="query")]
else:
return [result.embeddings[0]] raise TypeError("Only text PIL images supported as query")
def compute_source_embeddings( def compute_source_embeddings(
self, inputs: Union[TEXT, IMAGES], *args, **kwargs self, images: IMAGES, *args, **kwargs
) -> List[np.array]: ) -> List[np.array]:
""" images = self.sanitize_input(images)
Compute the embeddings for the inputs return [
self.generate_image_embedding(img, input_type="document") for img in images
Parameters ]
----------
inputs : Union[TEXT, IMAGES]
The inputs to embed. The input can be either str, bytes, Path (to an image),
PIL.Image or list of these.
Returns
-------
List[np.array]: the list of embeddings
"""
client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name):
inputs = sanitize_multimodal_input(inputs)
result = client.multimodal_embed(
inputs=inputs, model=self.name, input_type="document", **kwargs
)
else:
inputs = sanitize_text_input(inputs)
result = client.embed(
texts=inputs, model=self.name, input_type="document", **kwargs
)
return result.embeddings
@staticmethod @staticmethod
def _get_client(): def _get_client():

View File

@@ -152,104 +152,6 @@ def Vector(
return FixedSizeList return FixedSizeList
def MultiVector(
dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
) -> Type:
"""Pydantic MultiVector Type for multi-vector embeddings.
This type represents a list of vectors, each with the same dimension.
Useful for models that produce multiple embeddings per input, like ColPali.
Parameters
----------
dim : int
The dimension of each vector in the multi-vector.
value_type : pyarrow.DataType, optional
The value type of the vectors, by default pa.float32()
nullable : bool, optional
Whether the multi-vector is nullable, by default it is True.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import MultiVector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... text: str
... embeddings: MultiVector(128) # List of 128-dimensional vectors
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("text", pa.utf8(), False),
... pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128)))
... ])
"""
class MultiVectorList(list, FixedSizeListMixin):
def __repr__(self):
return f"MultiVector(dim={dim})"
@staticmethod
def nullable() -> bool:
return nullable
@staticmethod
def dim() -> int:
return dim
@staticmethod
def value_arrow_type() -> pa.DataType:
return value_type
@staticmethod
def is_multi_vector() -> bool:
return True
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.list_schema(
items_schema=core_schema.list_schema(
min_length=dim,
max_length=dim,
items_schema=core_schema.float_schema(),
),
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, (list, range)):
raise TypeError("A list of vectors is needed")
for vec in v:
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
raise TypeError(f"Each vector must be a list of {dim} numbers")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
return MultiVectorList
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert a field with native Python type to Arrow data type. """Convert a field with native Python type to Arrow data type.
@@ -304,9 +206,6 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
fields = _pydantic_model_to_fields(tp) fields = _pydantic_model_to_fields(tp)
return pa.struct(fields) return pa.struct(fields)
if issubclass(tp, FixedSizeListMixin): if issubclass(tp, FixedSizeListMixin):
if getattr(tp, "is_multi_vector", lambda: False)():
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
# For regular Vector
return pa.list_(tp.value_arrow_type(), tp.dim()) return pa.list_(tp.value_arrow_type(), tp.dim())
return _py_type_to_arrow_type(tp, field) return _py_type_to_arrow_type(tp, field)

View File

@@ -4,10 +4,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import abc
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from datetime import timedelta
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Dict, Dict,
@@ -28,8 +25,6 @@ import pyarrow.compute as pc
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
import pydantic import pydantic
from lancedb.pydantic import PYDANTIC_VERSION
from . import __version__ from . import __version__
from .arrow import AsyncRecordBatchReader from .arrow import AsyncRecordBatchReader
from .dependencies import pandas as pd from .dependencies import pandas as pd
@@ -88,213 +83,6 @@ def ensure_vector_query(
return val return val
class FullTextQueryType(Enum):
MATCH = "match"
MATCH_PHRASE = "match_phrase"
BOOST = "boost"
MULTI_MATCH = "multi_match"
class FullTextQuery(abc.ABC, pydantic.BaseModel):
@abc.abstractmethod
def query_type(self) -> FullTextQueryType:
"""
Get the query type of the query.
Returns
-------
str
The type of the query.
"""
@abc.abstractmethod
def to_dict(self) -> dict:
"""
Convert the query to a dictionary.
Returns
-------
dict
The query as a dictionary.
"""
class MatchQuery(FullTextQuery):
query: str
column: str
boost: float = 1.0
fuzziness: int = 0
max_expansions: int = 50
def __init__(
self,
query: str,
column: str,
*,
boost: float = 1.0,
fuzziness: int = 0,
max_expansions: int = 50,
):
"""
Match query for full-text search.
Parameters
----------
query : str
The query string to match against.
column : str
The name of the column to match against.
boost : float, default 1.0
The boost factor for the query.
The score of each matching document is multiplied by this value.
fuzziness : int, optional
The maximum edit distance for each term in the match query.
Defaults to 0 (exact match).
If None, fuzziness is applied automatically by the rules:
- 0 for terms with length <= 2
- 1 for terms with length <= 5
- 2 for terms with length > 5
max_expansions : int, optional
The maximum number of terms to consider for fuzzy matching.
Defaults to 50.
"""
super().__init__(
query=query,
column=column,
boost=boost,
fuzziness=fuzziness,
max_expansions=max_expansions,
)
def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MATCH
def to_dict(self) -> dict:
return {
"match": {
self.column: {
"query": self.query,
"boost": self.boost,
"fuzziness": self.fuzziness,
"max_expansions": self.max_expansions,
}
}
}
class PhraseQuery(FullTextQuery):
query: str
column: str
def __init__(self, query: str, column: str):
"""
Phrase query for full-text search.
Parameters
----------
query : str
The query string to match against.
column : str
The name of the column to match against.
"""
super().__init__(query=query, column=column)
def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MATCH_PHRASE
def to_dict(self) -> dict:
return {
"match_phrase": {
self.column: self.query,
}
}
class BoostQuery(FullTextQuery):
positive: FullTextQuery
negative: FullTextQuery
negative_boost: float = 0.5
def __init__(
self,
positive: FullTextQuery,
negative: FullTextQuery,
*,
negative_boost: float = 0.5,
):
"""
Boost query for full-text search.
Parameters
----------
positive : dict
The positive query object.
negative : dict
The negative query object.
negative_boost : float
The boost factor for the negative query.
"""
super().__init__(
positive=positive, negative=negative, negative_boost=negative_boost
)
def query_type(self) -> FullTextQueryType:
return FullTextQueryType.BOOST
def to_dict(self) -> dict:
return {
"boost": {
"positive": self.positive.to_dict(),
"negative": self.negative.to_dict(),
"negative_boost": self.negative_boost,
}
}
class MultiMatchQuery(FullTextQuery):
query: str
columns: list[str]
boosts: list[float]
def __init__(
self,
query: str,
columns: list[str],
*,
boosts: Optional[list[float]] = None,
):
"""
Multi-match query for full-text search.
Parameters
----------
query : str
The query string to match against.
columns : list[str]
The list of columns to match against.
boosts : list[float], optional
The list of boost factors for each column. If not provided,
all columns will have the same boost factor.
"""
if boosts is None:
boosts = [1.0] * len(columns)
super().__init__(query=query, columns=columns, boosts=boosts)
def query_type(self) -> FullTextQueryType:
return FullTextQueryType.MULTI_MATCH
def to_dict(self) -> dict:
return {
"multi_match": {
"query": self.query,
"columns": self.columns,
"boost": self.boosts,
}
}
class FullTextSearchQuery(pydantic.BaseModel): class FullTextSearchQuery(pydantic.BaseModel):
"""A LanceDB Full Text Search Query """A LanceDB Full Text Search Query
@@ -304,13 +92,18 @@ class FullTextSearchQuery(pydantic.BaseModel):
The columns to search The columns to search
If None, then the table should select the column automatically. If None, then the table should select the column automatically.
query: str | FullTextQuery query: str
If a string, it is treated as a MatchQuery. The query to search for
If a FullTextQuery object, it is used directly. limit: Optional[int] = None
The limit on the number of results to return
wand_factor: Optional[float] = None
The wand factor to use for the search
""" """
columns: Optional[List[str]] = None columns: Optional[List[str]] = None
query: Union[str, FullTextQuery] query: str
limit: Optional[int] = None
wand_factor: Optional[float] = None
class Query(pydantic.BaseModel): class Query(pydantic.BaseModel):
@@ -500,14 +293,10 @@ class Query(pydantic.BaseModel):
) )
return query return query
class Config:
# This tells pydantic to allow custom types (needed for the `vector` query since # This tells pydantic to allow custom types (needed for the `vector` query since
# pa.Array wouln't be allowed otherwise) # pa.Array wouln't be allowed otherwise)
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
else:
model_config = {"arbitrary_types_allowed": True}
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
@@ -568,7 +357,7 @@ class LanceQueryBuilder(ABC):
table, query, vector_column_name, fts_columns=fts_columns table, query, vector_column_name, fts_columns=fts_columns
) )
if isinstance(query, (str, FullTextQuery)): if isinstance(query, str):
# fts # fts
return LanceFtsQueryBuilder( return LanceFtsQueryBuilder(
table, table,
@@ -593,10 +382,8 @@ class LanceQueryBuilder(ABC):
# If query_type is fts, then query must be a string. # If query_type is fts, then query must be a string.
# otherwise raise TypeError # otherwise raise TypeError
if query_type == "fts": if query_type == "fts":
if not isinstance(query, (str, FullTextQuery)): if not isinstance(query, str):
raise TypeError( raise TypeError(f"'fts' queries must be a string: {type(query)}")
f"'fts' query must be a string or FullTextQuery: {type(query)}"
)
return query, query_type return query, query_type
elif query_type == "vector": elif query_type == "vector":
query = cls._query_to_vector(table, query, vector_column_name) query = cls._query_to_vector(table, query, vector_column_name)
@@ -657,12 +444,7 @@ class LanceQueryBuilder(ABC):
""" """
return self.to_pandas() return self.to_pandas()
def to_pandas( def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
self,
flatten: Optional[Union[int, bool]] = None,
*,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame":
""" """
Execute the query and return the results as a pandas DataFrame. Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
@@ -676,15 +458,12 @@ class LanceQueryBuilder(ABC):
If flatten is an integer, flatten the nested columns up to the If flatten is an integer, flatten the nested columns up to the
specified depth. specified depth.
If unspecified, do not flatten the nested columns. If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten) tbl = flatten_columns(self.to_arrow(), flatten)
return tbl.to_pandas() return tbl.to_pandas()
@abstractmethod @abstractmethod
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: def to_arrow(self) -> pa.Table:
""" """
Execute the query and return the results as an Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table). [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -692,65 +471,34 @@ class LanceQueryBuilder(ABC):
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vectors. vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def to_batches( def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
""" """
Execute the query and return the results as a pyarrow Execute the query and return the results as a pyarrow
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html) [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
Parameters
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
raise NotImplementedError raise NotImplementedError
def to_list(self, *, timeout: Optional[timedelta] = None) -> List[dict]: def to_list(self) -> List[dict]:
""" """
Execute the query and return the results as a list of dictionaries. Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys, Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance" or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected. fields are returned whether or not they're explicitly selected.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
return self.to_arrow(timeout=timeout).to_pylist() return self.to_arrow().to_pylist()
def to_pydantic( def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
) -> List[LanceModel]:
"""Return the table as a list of pydantic models. """Return the table as a list of pydantic models.
Parameters Parameters
---------- ----------
model: Type[LanceModel] model: Type[LanceModel]
The pydantic model to use. The pydantic model to use.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns Returns
------- -------
@@ -758,25 +506,19 @@ class LanceQueryBuilder(ABC):
""" """
return [ return [
model(**{k: v for k, v in row.items() if k in model.field_names()}) model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow(timeout=timeout).to_pylist() for row in self.to_arrow().to_pylist()
] ]
def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame": def to_polars(self) -> "pl.DataFrame":
""" """
Execute the query and return the results as a Polars DataFrame. Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
import polars as pl import polars as pl
return pl.from_arrow(self.to_arrow(timeout=timeout)) return pl.from_arrow(self.to_arrow())
def limit(self, limit: Union[int, None]) -> Self: def limit(self, limit: Union[int, None]) -> Self:
"""Set the maximum number of results to return. """Set the maximum number of results to return.
@@ -915,45 +657,7 @@ class LanceQueryBuilder(ABC):
------- -------
plan : str plan : str
""" # noqa: E501 """ # noqa: E501
return self._table._explain_plan(self.to_query_object(), verbose=verbose) return self._table._explain_plan(self.to_query_object())
def analyze_plan(self) -> str:
"""
Run the query and return its execution plan with runtime metrics.
This returns detailed metrics for each step, such as elapsed time,
rows processed, bytes read, and I/O stats. It is useful for debugging
and performance tuning.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).analyze_plan()
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
AnalyzeExec verbose=true, metrics=[]
ProjectionExec: expr=[...], metrics=[...]
GlobalLimitExec: skip=0, fetch=10, metrics=[...]
FilterExec: _distance@2 IS NOT NULL,
metrics=[output_rows=..., elapsed_compute=...]
SortExec: TopK(fetch=10), expr=[...],
preserve_partitioning=[...],
metrics=[output_rows=..., elapsed_compute=..., row_replacements=...]
KNNVectorDistance: metric=l2,
metrics=[output_rows=..., elapsed_compute=..., output_batches=...]
LanceScan: uri=..., projection=[vector], row_id=true,
row_addr=false, ordered=false,
metrics=[output_rows=..., elapsed_compute=...,
bytes_read=..., iops=..., requests=...]
Returns
-------
plan : str
The physical query execution plan with runtime metrics.
"""
return self._table._analyze_plan(self.to_query_object())
def vector(self, vector: Union[np.ndarray, list]) -> Self: def vector(self, vector: Union[np.ndarray, list]) -> Self:
"""Set the vector to search for. """Set the vector to search for.
@@ -970,14 +674,13 @@ class LanceQueryBuilder(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def text(self, text: str | FullTextQuery) -> Self: def text(self, text: str) -> Self:
"""Set the text to search for. """Set the text to search for.
Parameters Parameters
---------- ----------
text: str | FullTextQuery text: str
If a string, it is treated as a MatchQuery. The text to search for.
If a FullTextQuery object, it is used directly.
Returns Returns
------- -------
@@ -1191,7 +894,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor self._refine_factor = refine_factor
return self return self
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: def to_arrow(self) -> pa.Table:
""" """
Execute the query and return the results as an Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table). [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -1199,14 +902,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vectors. vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
""" """
return self.to_batches(timeout=timeout).read_all() return self.to_batches().read_all()
def to_query_object(self) -> Query: def to_query_object(self) -> Query:
""" """
@@ -1236,13 +933,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
bypass_vector_index=self._bypass_vector_index, bypass_vector_index=self._bypass_vector_index,
) )
def to_batches( def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
""" """
Execute the query and return the result as a RecordBatchReader object. Execute the query and return the result as a RecordBatchReader object.
@@ -1250,9 +941,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
---------- ----------
batch_size: int batch_size: int
The maximum number of selected records in a RecordBatch object. The maximum number of selected records in a RecordBatch object.
timeout: timedelta, default None
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns Returns
------- -------
@@ -1262,9 +950,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
if isinstance(vector[0], np.ndarray): if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector] vector = [v.tolist() for v in vector]
query = self.to_query_object() query = self.to_query_object()
result_set = self._table._execute_query( result_set = self._table._execute_query(query, batch_size)
query, batch_size=batch_size, timeout=timeout
)
if self._reranker is not None: if self._reranker is not None:
rs_table = result_set.read_all() rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table) result_set = self._reranker.rerank_vector(self._str_query, rs_table)
@@ -1360,7 +1046,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__( def __init__(
self, self,
table: "Table", table: "Table",
query: str | FullTextQuery, query: str,
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None, fts_columns: Optional[Union[str, List[str]]] = None,
): ):
@@ -1403,7 +1089,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
offset=self._offset, offset=self._offset,
) )
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: def to_arrow(self) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path() path, fs, exist = self._table._get_fts_index_path()
if exist: if exist:
return self.tantivy_to_arrow() return self.tantivy_to_arrow()
@@ -1415,16 +1101,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Use tantivy-based index instead for now." "Use tantivy-based index instead for now."
) )
query = self.to_query_object() query = self.to_query_object()
results = self._table._execute_query(query, timeout=timeout) results = self._table._execute_query(query)
results = results.read_all() results = results.read_all()
if self._reranker is not None: if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results) results = self._reranker.rerank_fts(self._query, results)
check_reranker_result(results) check_reranker_result(results)
return results return results
def to_batches( def to_batches(self, /, batch_size: Optional[int] = None):
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches on an FTS query") raise NotImplementedError("to_batches on an FTS query")
def tantivy_to_arrow(self) -> pa.Table: def tantivy_to_arrow(self) -> pa.Table:
@@ -1529,8 +1213,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: def to_arrow(self) -> pa.Table:
return self.to_batches(timeout=timeout).read_all() return self.to_batches().read_all()
def to_query_object(self) -> Query: def to_query_object(self) -> Query:
return Query( return Query(
@@ -1541,11 +1225,9 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
offset=self._offset, offset=self._offset,
) )
def to_batches( def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
) -> pa.RecordBatchReader:
query = self.to_query_object() query = self.to_query_object()
return self._table._execute_query(query, batch_size=batch_size, timeout=timeout) return self._table._execute_query(query, batch_size)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder: def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker. """Rerank the results using the specified reranker.
@@ -1578,7 +1260,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__( def __init__(
self, self,
table: "Table", table: "Table",
query: Optional[Union[str, FullTextQuery]] = None, query: Optional[str] = None,
vector_column: Optional[str] = None, vector_column: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None, fts_columns: Optional[Union[str, List[str]]] = None,
): ):
@@ -1592,8 +1274,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._refine_factor = None self._refine_factor = None
self._distance_type = None self._distance_type = None
self._phrase_query = None self._phrase_query = None
self._lower_bound = None
self._upper_bound = None
def _validate_query(self, query, vector=None, text=None): def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None): if query is not None and (vector is not None or text is not None):
@@ -1610,8 +1290,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
text_query = text or query text_query = text or query
if text_query is None: if text_query is None:
raise ValueError("Text query must be provided for hybrid search.") raise ValueError("Text query must be provided for hybrid search.")
if not isinstance(text_query, (str, FullTextQuery)): if not isinstance(text_query, str):
raise ValueError("Text query must be a string or FullTextQuery") raise ValueError("Text query must be a string")
return vector_query, text_query return vector_query, text_query
@@ -1635,7 +1315,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def to_query_object(self) -> Query: def to_query_object(self) -> Query:
raise NotImplementedError("to_query_object not yet supported on a hybrid query") raise NotImplementedError("to_query_object not yet supported on a hybrid query")
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table: def to_arrow(self) -> pa.Table:
vector_query, fts_query = self._validate_query( vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text self._query, self._vector, self._text
) )
@@ -1673,20 +1353,14 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.ef(self._ef) self._vector_query.ef(self._ef)
if self._bypass_vector_index: if self._bypass_vector_index:
self._vector_query.bypass_vector_index() self._vector_query.bypass_vector_index()
if self._lower_bound or self._upper_bound:
self._vector_query.distance_range(
lower_bound=self._lower_bound, upper_bound=self._upper_bound
)
if self._reranker is None: if self._reranker is None:
self._reranker = RRFReranker() self._reranker = RRFReranker()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
fts_future = executor.submit( fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
self._fts_query.with_row_id(True).to_arrow, timeout=timeout
)
vector_future = executor.submit( vector_future = executor.submit(
self._vector_query.with_row_id(True).to_arrow, timeout=timeout self._vector_query.with_row_id(True).to_arrow
) )
fts_results = fts_future.result() fts_results = fts_future.result()
vector_results = vector_future.result() vector_results = vector_future.result()
@@ -1773,9 +1447,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return results return results
def to_batches( def to_batches(self):
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches not yet supported on a hybrid query") raise NotImplementedError("to_batches not yet supported on a hybrid query")
@staticmethod @staticmethod
@@ -1981,7 +1653,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector = vector self._vector = vector
return self return self
def text(self, text: str | FullTextQuery) -> LanceHybridQueryBuilder: def text(self, text: str) -> LanceHybridQueryBuilder:
self._text = text self._text = text
return self return self
@@ -2139,10 +1811,7 @@ class AsyncQueryBase(object):
return self return self
async def to_batches( async def to_batches(
self, self, *, max_batch_length: Optional[int] = None
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader: ) -> AsyncRecordBatchReader:
""" """
Execute the query and return the results as an Apache Arrow RecordBatchReader. Execute the query and return the results as an Apache Arrow RecordBatchReader.
@@ -2155,56 +1824,34 @@ class AsyncQueryBase(object):
If not specified, a default batch length is used. If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks. underlying data is stored in smaller chunks.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
""" """
return AsyncRecordBatchReader( return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
await self._inner.execute(max_batch_length, timeout)
)
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table: async def to_arrow(self) -> pa.Table:
""" """
Execute the query and collect the results into an Apache Arrow Table. Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use you expect a large number of results, you may want to use
[to_batches][lancedb.query.AsyncQueryBase.to_batches] [to_batches][lancedb.query.AsyncQueryBase.to_batches]
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
""" """
batch_iter = await self.to_batches(timeout=timeout) batch_iter = await self.to_batches()
return pa.Table.from_batches( return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema await batch_iter.read_all(), schema=batch_iter.schema
) )
async def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]: async def to_list(self) -> List[dict]:
""" """
Execute the query and return the results as a list of dictionaries. Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys, Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance" or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected. fields are returned whether or not they're explicitly selected.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
""" """
return (await self.to_arrow(timeout=timeout)).to_pylist() return (await self.to_arrow()).to_pylist()
async def to_pandas( async def to_pandas(
self, self, flatten: Optional[Union[int, bool]] = None
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame": ) -> "pd.DataFrame":
""" """
Execute the query and collect the results into a pandas DataFrame. Execute the query and collect the results into a pandas DataFrame.
@@ -2233,19 +1880,10 @@ class AsyncQueryBase(object):
If flatten is an integer, flatten the nested columns up to the If flatten is an integer, flatten the nested columns up to the
specified depth. specified depth.
If unspecified, do not flatten the nested columns. If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
""" """
return ( return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
).to_pandas()
async def to_polars( async def to_polars(self) -> "pl.DataFrame":
self,
timeout: Optional[timedelta] = None,
) -> "pl.DataFrame":
""" """
Execute the query and collect the results into a Polars DataFrame. Execute the query and collect the results into a Polars DataFrame.
@@ -2254,13 +1892,6 @@ class AsyncQueryBase(object):
[to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to [to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
polars separately. polars separately.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
Examples Examples
-------- --------
@@ -2276,7 +1907,7 @@ class AsyncQueryBase(object):
""" """
import polars as pl import polars as pl
return pl.from_arrow(await self.to_arrow(timeout=timeout)) return pl.from_arrow(await self.to_arrow())
async def explain_plan(self, verbose: Optional[bool] = False): async def explain_plan(self, verbose: Optional[bool] = False):
"""Return the execution plan for this query. """Return the execution plan for this query.
@@ -2310,15 +1941,6 @@ class AsyncQueryBase(object):
""" # noqa: E501 """ # noqa: E501
return await self._inner.explain_plan(verbose) return await self._inner.explain_plan(verbose)
async def analyze_plan(self):
"""Execute the query and display with runtime metrics.
Returns
-------
plan : str
"""
return await self._inner.analyze_plan()
class AsyncQuery(AsyncQueryBase): class AsyncQuery(AsyncQueryBase):
def __init__(self, inner: LanceQuery): def __init__(self, inner: LanceQuery):
@@ -2419,7 +2041,7 @@ class AsyncQuery(AsyncQueryBase):
) )
def nearest_to_text( def nearest_to_text(
self, query: str | FullTextQuery, columns: Union[str, List[str], None] = None self, query: str, columns: Union[str, List[str], None] = None
) -> AsyncFTSQuery: ) -> AsyncFTSQuery:
""" """
Find the documents that are most relevant to the given text query. Find the documents that are most relevant to the given text query.
@@ -2445,13 +2067,9 @@ class AsyncQuery(AsyncQueryBase):
columns = [columns] columns = [columns]
if columns is None: if columns is None:
columns = [] columns = []
if isinstance(query, str):
return AsyncFTSQuery( return AsyncFTSQuery(
self._inner.nearest_to_text({"query": query, "columns": columns}) self._inner.nearest_to_text({"query": query, "columns": columns})
) )
# FullTextQuery object
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
class AsyncFTSQuery(AsyncQueryBase): class AsyncFTSQuery(AsyncQueryBase):
@@ -2547,12 +2165,9 @@ class AsyncFTSQuery(AsyncQueryBase):
) )
async def to_batches( async def to_batches(
self, self, *, max_batch_length: Optional[int] = None
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader: ) -> AsyncRecordBatchReader:
reader = await super().to_batches(timeout=timeout) reader = await super().to_batches()
results = pa.Table.from_batches(await reader.read_all(), reader.schema) results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker: if self._reranker:
results = self._reranker.rerank_fts(self.get_query(), results) results = self._reranker.rerank_fts(self.get_query(), results)
@@ -2737,7 +2352,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
return self return self
def nearest_to_text( def nearest_to_text(
self, query: str | FullTextQuery, columns: Union[str, List[str], None] = None self, query: str, columns: Union[str, List[str], None] = None
) -> AsyncHybridQuery: ) -> AsyncHybridQuery:
""" """
Find the documents that are most relevant to the given text query, Find the documents that are most relevant to the given text query,
@@ -2767,21 +2382,14 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
columns = [columns] columns = [columns]
if columns is None: if columns is None:
columns = [] columns = []
if isinstance(query, str):
return AsyncHybridQuery( return AsyncHybridQuery(
self._inner.nearest_to_text({"query": query, "columns": columns}) self._inner.nearest_to_text({"query": query, "columns": columns})
) )
# FullTextQuery object
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
async def to_batches( async def to_batches(
self, self, *, max_batch_length: Optional[int] = None
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader: ) -> AsyncRecordBatchReader:
reader = await super().to_batches(timeout=timeout) reader = await super().to_batches()
results = pa.Table.from_batches(await reader.read_all(), reader.schema) results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker: if self._reranker:
results = self._reranker.rerank_vector(self._query_string, results) results = self._reranker.rerank_vector(self._query_string, results)
@@ -2837,10 +2445,7 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
return self return self
async def to_batches( async def to_batches(
self, self, *, max_batch_length: Optional[int] = None
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader: ) -> AsyncRecordBatchReader:
fts_query = AsyncFTSQuery(self._inner.to_fts_query()) fts_query = AsyncFTSQuery(self._inner.to_fts_query())
vec_query = AsyncVectorQuery(self._inner.to_vector_query()) vec_query = AsyncVectorQuery(self._inner.to_vector_query())
@@ -2852,8 +2457,8 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
vec_query.with_row_id() vec_query.with_row_id()
fts_results, vector_results = await asyncio.gather( fts_results, vector_results = await asyncio.gather(
fts_query.to_arrow(timeout=timeout), fts_query.to_arrow(),
vec_query.to_arrow(timeout=timeout), vec_query.to_arrow(),
) )
result = LanceHybridQueryBuilder._combine_hybrid_results( result = LanceHybridQueryBuilder._combine_hybrid_results(
@@ -2905,7 +2510,7 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
Returns Returns
------- -------
plan : str plan
""" # noqa: E501 """ # noqa: E501
results = ["Vector Search Plan:"] results = ["Vector Search Plan:"]
@@ -2914,23 +2519,3 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
results.append(await self._inner.to_fts_query().explain_plan(verbose)) results.append(await self._inner.to_fts_query().explain_plan(verbose))
return "\n".join(results) return "\n".join(results)
async def analyze_plan(self):
"""
Execute the query and return the physical execution plan with runtime metrics.
This runs both the vector and FTS (full-text search) queries and returns
detailed metrics for each step of execution—such as rows processed,
elapsed time, I/O stats, and more. Its useful for debugging and
performance analysis.
Returns
-------
plan : str
"""
results = ["Vector Search Query:"]
results.append(await self._inner.to_vector_query().analyze_plan())
results.append("FTS Search Query:")
results.append(await self._inner.to_fts_query().analyze_plan())
return "\n".join(results)

View File

@@ -87,9 +87,6 @@ class RemoteTable(Table):
def checkout_latest(self): def checkout_latest(self):
return LOOP.run(self._table.checkout_latest()) return LOOP.run(self._table.checkout_latest())
def restore(self, version: Optional[int] = None):
return LOOP.run(self._table.restore(version))
def list_indices(self) -> Iterable[IndexConfig]: def list_indices(self) -> Iterable[IndexConfig]:
"""List all the indices on the table""" """List all the indices on the table"""
return LOOP.run(self._table.list_indices()) return LOOP.run(self._table.list_indices())
@@ -104,7 +101,6 @@ class RemoteTable(Table):
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar", index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*, *,
replace: bool = False, replace: bool = False,
wait_timeout: timedelta = None,
): ):
"""Creates a scalar index """Creates a scalar index
Parameters Parameters
@@ -127,18 +123,13 @@ class RemoteTable(Table):
else: else:
raise ValueError(f"Unknown index type: {index_type}") raise ValueError(f"Unknown index type: {index_type}")
LOOP.run( LOOP.run(self._table.create_index(column, config=config, replace=replace))
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_fts_index( def create_fts_index(
self, self,
column: str, column: str,
*, *,
replace: bool = False, replace: bool = False,
wait_timeout: timedelta = None,
with_position: bool = True, with_position: bool = True,
# tokenizer configs: # tokenizer configs:
base_tokenizer: str = "simple", base_tokenizer: str = "simple",
@@ -159,11 +150,7 @@ class RemoteTable(Table):
remove_stop_words=remove_stop_words, remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding, ascii_folding=ascii_folding,
) )
LOOP.run( LOOP.run(self._table.create_index(column, config=config, replace=replace))
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_index( def create_index(
self, self,
@@ -175,7 +162,6 @@ class RemoteTable(Table):
replace: Optional[bool] = None, replace: Optional[bool] = None,
accelerator: Optional[str] = None, accelerator: Optional[str] = None,
index_type="vector", index_type="vector",
wait_timeout: Optional[timedelta] = None,
): ):
"""Create an index on the table. """Create an index on the table.
Currently, the only parameters that matter are Currently, the only parameters that matter are
@@ -247,11 +233,7 @@ class RemoteTable(Table):
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" " 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
) )
LOOP.run( LOOP.run(self._table.create_index(vector_column_name, config=config))
self._table.create_index(
vector_column_name, config=config, wait_timeout=wait_timeout
)
)
def add( def add(
self, self,
@@ -370,15 +352,9 @@ class RemoteTable(Table):
) )
def _execute_query( def _execute_query(
self, self, query: Query, batch_size: Optional[int] = None
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
async_iter = LOOP.run( async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync(): def iter_sync():
try: try:
@@ -389,12 +365,6 @@ class RemoteTable(Table):
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
def _explain_plan(self, query: Query, verbose: Optional[bool] = False) -> str:
return LOOP.run(self._table._explain_plan(query, verbose))
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
that can be used to create a "merge insert" operation. that can be used to create a "merge insert" operation.
@@ -569,11 +539,6 @@ class RemoteTable(Table):
def drop_index(self, index_name: str): def drop_index(self, index_name: str):
return LOOP.run(self._table.drop_index(index_name)) return LOOP.run(self._table.drop_index(index_name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
):
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def uses_v2_manifest_paths(self) -> bool: def uses_v2_manifest_paths(self) -> bool:
raise NotImplementedError( raise NotImplementedError(
"uses_v2_manifest_paths() is not supported on the LanceDB Cloud" "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"

View File

@@ -47,9 +47,6 @@ class AnswerdotaiRerankers(Reranker):
) )
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
doc_ids = list(range(len(docs))) doc_ids = list(range(len(docs)))
result = self.reranker.rank(query, docs, doc_ids=doc_ids) result = self.reranker.rank(query, docs, doc_ids=doc_ids)
@@ -86,6 +83,7 @@ class AnswerdotaiRerankers(Reranker):
vector_results = self._rerank(vector_results, query) vector_results = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")]) vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results return vector_results
@@ -93,5 +91,7 @@ class AnswerdotaiRerankers(Reranker):
fts_results = self._rerank(fts_results, query) fts_results = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")]) fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results return fts_results

View File

@@ -65,16 +65,6 @@ class Reranker(ABC):
f"{self.__class__.__name__} does not implement rerank_vector" f"{self.__class__.__name__} does not implement rerank_vector"
) )
def _handle_empty_results(self, results: pa.Table):
"""
Helper method to handle empty FTS results consistently
"""
if len(results) > 0:
return results
return results.append_column(
"_relevance_score", pa.array([], type=pa.float32())
)
def rerank_fts( def rerank_fts(
self, self,
query: str, query: str,

View File

@@ -62,9 +62,6 @@ class CohereReranker(Reranker):
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
response = self._client.rerank( response = self._client.rerank(
query=query, query=query,
@@ -102,14 +99,24 @@ class CohereReranker(Reranker):
) )
return combined_results return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table): def rerank_vector(
vector_results = self._rerank(vector_results, query) self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) result_set = result_set.drop_columns(["_distance"])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table): return result_set
fts_results = self._rerank(fts_results, query)
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) result_set = result_set.drop_columns(["_score"])
return fts_results
return result_set

View File

@@ -63,9 +63,6 @@ class CrossEncoderReranker(Reranker):
return cross_encoder return cross_encoder
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
passages = result_set[self.column].to_pylist() passages = result_set[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages] cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp) cross_scores = self.model.predict(cross_inp)
@@ -96,7 +93,11 @@ class CrossEncoderReranker(Reranker):
return combined_results return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table): def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
vector_results = self._rerank(vector_results, query) vector_results = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) vector_results = vector_results.drop_columns(["_distance"])
@@ -104,7 +105,11 @@ class CrossEncoderReranker(Reranker):
vector_results = vector_results.sort_by([("_relevance_score", "descending")]) vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table): def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
fts_results = self._rerank(fts_results, query) fts_results = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) fts_results = fts_results.drop_columns(["_score"])

View File

@@ -62,9 +62,6 @@ class JinaReranker(Reranker):
return self._session return self._session
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
response = self._client.post( # type: ignore response = self._client.post( # type: ignore
API_URL, API_URL,
@@ -107,14 +104,24 @@ class JinaReranker(Reranker):
) )
return combined_results return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table): def rerank_vector(
vector_results = self._rerank(vector_results, query) self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) result_set = result_set.drop_columns(["_distance"])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table): return result_set
fts_results = self._rerank(fts_results, query)
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) result_set = result_set.drop_columns(["_score"])
return fts_results
return result_set

View File

@@ -44,9 +44,6 @@ class OpenaiReranker(Reranker):
self.api_key = api_key self.api_key = api_key
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
response = self._client.chat.completions.create( response = self._client.chat.completions.create(
model=self.model_name, model=self.model_name,
@@ -107,14 +104,18 @@ class OpenaiReranker(Reranker):
vector_results = self._rerank(vector_results, query) vector_results = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")]) vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table): def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query) fts_results = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")]) fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results return fts_results
@cached_property @cached_property

View File

@@ -63,9 +63,6 @@ class VoyageAIReranker(Reranker):
) )
def _rerank(self, result_set: pa.Table, query: str): def _rerank(self, result_set: pa.Table, query: str):
result_set = self._handle_empty_results(result_set)
if len(result_set) == 0:
return result_set
docs = result_set[self.column].to_pylist() docs = result_set[self.column].to_pylist()
response = self._client.rerank( response = self._client.rerank(
query=query, query=query,
@@ -104,14 +101,24 @@ class VoyageAIReranker(Reranker):
) )
return combined_results return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table): def rerank_vector(
vector_results = self._rerank(vector_results, query) self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance": if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"]) result_set = result_set.drop_columns(["_distance"])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table): return result_set
fts_results = self._rerank(fts_results, query)
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance": if self.score == "relevance":
fts_results = fts_results.drop_columns(["_score"]) result_set = result_set.drop_columns(["_score"])
return fts_results
return result_set

View File

@@ -52,7 +52,6 @@ from .query import (
AsyncHybridQuery, AsyncHybridQuery,
AsyncQuery, AsyncQuery,
AsyncVectorQuery, AsyncVectorQuery,
FullTextQuery,
LanceEmptyQueryBuilder, LanceEmptyQueryBuilder,
LanceFtsQueryBuilder, LanceFtsQueryBuilder,
LanceHybridQueryBuilder, LanceHybridQueryBuilder,
@@ -631,7 +630,6 @@ class Table(ABC):
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
*, *,
index_type: VectorIndexType = "IVF_PQ", index_type: VectorIndexType = "IVF_PQ",
wait_timeout: Optional[timedelta] = None,
num_bits: int = 8, num_bits: int = 8,
max_iterations: int = 50, max_iterations: int = 50,
sample_rate: int = 256, sample_rate: int = 256,
@@ -667,8 +665,6 @@ class Table(ABC):
num_bits: int num_bits: int
The number of bits to encode sub-vectors. Only used with the IVF_PQ index. The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
Only 4 and 8 are supported. Only 4 and 8 are supported.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
""" """
raise NotImplementedError raise NotImplementedError
@@ -692,23 +688,6 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def create_scalar_index( def create_scalar_index(
self, self,
@@ -716,7 +695,6 @@ class Table(ABC):
*, *,
replace: bool = True, replace: bool = True,
index_type: ScalarIndexType = "BTREE", index_type: ScalarIndexType = "BTREE",
wait_timeout: Optional[timedelta] = None,
): ):
"""Create a scalar index on a column. """Create a scalar index on a column.
@@ -729,8 +707,7 @@ class Table(ABC):
Replace the existing index if it exists. Replace the existing index if it exists.
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE" index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
The type of index to create. The type of index to create.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
Examples Examples
-------- --------
@@ -789,7 +766,6 @@ class Table(ABC):
stem: bool = False, stem: bool = False,
remove_stop_words: bool = False, remove_stop_words: bool = False,
ascii_folding: bool = False, ascii_folding: bool = False,
wait_timeout: Optional[timedelta] = None,
): ):
"""Create a full-text search index on the table. """Create a full-text search index on the table.
@@ -845,8 +821,6 @@ class Table(ABC):
ascii_folding : bool, default False ascii_folding : bool, default False
Whether to fold ASCII characters. This converts accented characters to Whether to fold ASCII characters. This converts accented characters to
their ASCII equivalent. For example, "café" would be converted to "cafe". their ASCII equivalent. For example, "café" would be converted to "cafe".
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
""" """
raise NotImplementedError raise NotImplementedError
@@ -945,9 +919,7 @@ class Table(ABC):
@abstractmethod @abstractmethod
def search( def search(
self, self,
query: Optional[ query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: QueryType = "auto", query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
@@ -1032,19 +1004,9 @@ class Table(ABC):
@abstractmethod @abstractmethod
def _execute_query( def _execute_query(
self, self, query: Query, batch_size: Optional[int] = None
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ... ) -> pa.RecordBatchReader: ...
@abstractmethod
def _explain_plan(self, query: Query, verbose: Optional[bool] = False) -> str: ...
@abstractmethod
def _analyze_plan(self, query: Query) -> str: ...
@abstractmethod @abstractmethod
def _do_merge( def _do_merge(
self, self,
@@ -1300,21 +1262,16 @@ class Table(ABC):
""" """
@abstractmethod @abstractmethod
def add_columns( def add_columns(self, transforms: Dict[str, str]):
self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
):
""" """
Add new columns with defined values. Add new columns with defined values.
Parameters Parameters
---------- ----------
transforms: Dict[str, str], pa.Field, List[pa.Field], pa.Schema transforms: Dict[str, str]
A map of column name to a SQL expression to use to calculate the A map of column name to a SQL expression to use to calculate the
value of the new column. These expressions will be evaluated for value of the new column. These expressions will be evaluated for
each row in the table, and can reference existing columns. each row in the table, and can reference existing columns.
Alternatively, a pyarrow Field or Schema can be provided to add
new columns with the specified data types. The new columns will
be initialized with null values.
""" """
@abstractmethod @abstractmethod
@@ -1382,21 +1339,6 @@ class Table(ABC):
It can also be used to undo a `[Self::checkout]` operation It can also be used to undo a `[Self::checkout]` operation
""" """
@abstractmethod
def restore(self, version: Optional[int] = None):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
specified previous version. Data is not copied (as of python-v0.2.1).
Parameters
----------
version : int, default None
The version to restore. If unspecified then restores the currently
checked out version. If the currently checked out version is the
latest version then this is a no-op.
"""
@abstractmethod @abstractmethod
def list_versions(self) -> List[Dict[str, Any]]: def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table""" """List all versions of the table"""
@@ -1770,37 +1712,8 @@ class LanceTable(Table):
) )
def drop_index(self, name: str) -> None: def drop_index(self, name: str) -> None:
"""
Drops an index from the table
Parameters
----------
name: str
The name of the index to drop
"""
return LOOP.run(self._table.drop_index(name)) return LOOP.run(self._table.drop_index(name))
def prewarm_index(self, name: str) -> None:
"""
Prewarms an index in the table
This loads the entire index into memory
If the index does not fit into the available cache this call
may be wasteful
Parameters
----------
name: str
The name of the index to prewarm
"""
return LOOP.run(self._table.prewarm_index(name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def create_scalar_index( def create_scalar_index(
self, self,
column: str, column: str,
@@ -2100,9 +2013,7 @@ class LanceTable(Table):
@overload @overload
def search( def search(
self, self,
query: Optional[ query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: Literal["hybrid"] = "hybrid", query_type: Literal["hybrid"] = "hybrid",
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
@@ -2121,9 +2032,7 @@ class LanceTable(Table):
def search( def search(
self, self,
query: Optional[ query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: QueryType = "auto", query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
@@ -2195,8 +2104,6 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if isinstance(query, FullTextQuery):
query_type = "fts"
vector_column_name = infer_vector_column_name( vector_column_name = infer_vector_column_name(
schema=self.schema, schema=self.schema,
query_type=query_type, query_type=query_type,
@@ -2372,15 +2279,9 @@ class LanceTable(Table):
LOOP.run(self._table.update(values, where=where, updates_sql=values_sql)) LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
def _execute_query( def _execute_query(
self, self, query: Query, batch_size: Optional[int] = None
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
async_iter = LOOP.run( async_iter = LOOP.run(self._table._execute_query(query, batch_size))
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync(): def iter_sync():
try: try:
@@ -2391,11 +2292,8 @@ class LanceTable(Table):
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync()) return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
def _explain_plan(self, query: Query, verbose: Optional[bool] = False) -> str: def _explain_plan(self, query: Query) -> str:
return LOOP.run(self._table._explain_plan(query, verbose)) return LOOP.run(self._table._explain_plan(query))
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def _do_merge( def _do_merge(
self, self,
@@ -2544,9 +2442,7 @@ class LanceTable(Table):
""" """
return LOOP.run(self._table.index_stats(index_name)) return LOOP.run(self._table.index_stats(index_name))
def add_columns( def add_columns(self, transforms: Dict[str, str]):
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
):
LOOP.run(self._table.add_columns(transforms)) LOOP.run(self._table.add_columns(transforms))
def alter_columns(self, *alterations: Iterable[Dict[str, str]]): def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
@@ -2994,7 +2890,6 @@ class AsyncTable:
config: Optional[ config: Optional[
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None, ] = None,
wait_timeout: Optional[timedelta] = None,
): ):
"""Create an index to speed up queries """Create an index to speed up queries
@@ -3019,8 +2914,6 @@ class AsyncTable:
For advanced configuration you can specify the type of index you would For advanced configuration you can specify the type of index you would
like to create. You can also specify index-specific parameters when like to create. You can also specify index-specific parameters when
creating an index object. creating an index object.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
""" """
if config is not None: if config is not None:
if not isinstance( if not isinstance(
@@ -3031,9 +2924,7 @@ class AsyncTable:
" Bitmap, LabelList, or FTS" " Bitmap, LabelList, or FTS"
) )
try: try:
await self._inner.create_index( await self._inner.create_index(column, index=config, replace=replace)
column, index=config, replace=replace, wait_timeout=wait_timeout
)
except ValueError as e: except ValueError as e:
if "not support the requested language" in str(e): if "not support the requested language" in str(e):
supported_langs = ", ".join(lang_mapping.values()) supported_langs = ", ".join(lang_mapping.values())
@@ -3061,40 +2952,6 @@ class AsyncTable:
""" """
await self._inner.drop_index(name) await self._inner.drop_index(name)
async def prewarm_index(self, name: str) -> None:
"""
Prewarm an index in the table.
Parameters
----------
name: str
The name of the index to prewarm
Notes
-----
This will load the index into memory. This may reduce the cold-start time for
future queries. If the index does not fit in the cache then this call may be
wasteful.
"""
await self._inner.prewarm_index(name)
async def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
await self._inner.wait_for_index(index_names, timeout)
async def add( async def add(
self, self,
data: DATA, data: DATA,
@@ -3246,9 +3103,7 @@ class AsyncTable:
@overload @overload
async def search( async def search(
self, self,
query: Optional[ query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: Literal["vector"] = ..., query_type: Literal["vector"] = ...,
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
@@ -3257,9 +3112,7 @@ class AsyncTable:
async def search( async def search(
self, self,
query: Optional[ query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: QueryType = "auto", query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None, ordering_field_name: Optional[str] = None,
@@ -3318,10 +3171,8 @@ class AsyncTable:
async def get_embedding_func( async def get_embedding_func(
vector_column_name: Optional[str], vector_column_name: Optional[str],
query_type: QueryType, query_type: QueryType,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]], query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
) -> Tuple[str, EmbeddingFunctionConfig]: ) -> Tuple[str, EmbeddingFunctionConfig]:
if isinstance(query, FullTextQuery):
query_type = "fts"
schema = await self.schema() schema = await self.schema()
vector_column_name = infer_vector_column_name( vector_column_name = infer_vector_column_name(
schema=schema, schema=schema,
@@ -3371,8 +3222,6 @@ class AsyncTable:
if is_embedding(query): if is_embedding(query):
vector_query = query vector_query = query
query_type = "vector" query_type = "vector"
elif isinstance(query, FullTextQuery):
query_type = "fts"
elif isinstance(query, str): elif isinstance(query, str):
try: try:
( (
@@ -3493,15 +3342,13 @@ class AsyncTable:
async_query = async_query.nearest_to_text( async_query = async_query.nearest_to_text(
query.full_text_query.query, query.full_text_query.columns query.full_text_query.query, query.full_text_query.columns
) )
if query.full_text_query.limit is not None:
async_query = async_query.limit(query.full_text_query.limit)
return async_query return async_query
async def _execute_query( async def _execute_query(
self, self, query: Query, batch_size: Optional[int] = None
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
# The sync table calls into this method, so we need to map the # The sync table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only # query to the async version of the query and run that here. This is only
@@ -3509,19 +3356,12 @@ class AsyncTable:
async_query = self._sync_query_to_async(query) async_query = self._sync_query_to_async(query)
return await async_query.to_batches( return await async_query.to_batches(max_batch_length=batch_size)
max_batch_length=batch_size, timeout=timeout
)
async def _explain_plan(self, query: Query, verbose: Optional[bool]) -> str: async def _explain_plan(self, query: Query) -> str:
# This method is used by the sync table # This method is used by the sync table
async_query = self._sync_query_to_async(query) async_query = self._sync_query_to_async(query)
return await async_query.explain_plan(verbose) return await async_query.explain_plan()
async def _analyze_plan(self, query: Query) -> str:
# This method is used by the sync table
async_query = self._sync_query_to_async(query)
return await async_query.analyze_plan()
async def _do_merge( async def _do_merge(
self, self,
@@ -3661,9 +3501,7 @@ class AsyncTable:
return await self._inner.update(updates_sql, where) return await self._inner.update(updates_sql, where)
async def add_columns( async def add_columns(self, transforms: dict[str, str]):
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
):
""" """
Add new columns with defined values. Add new columns with defined values.
@@ -3673,18 +3511,7 @@ class AsyncTable:
A map of column name to a SQL expression to use to calculate the A map of column name to a SQL expression to use to calculate the
value of the new column. These expressions will be evaluated for value of the new column. These expressions will be evaluated for
each row in the table, and can reference existing columns. each row in the table, and can reference existing columns.
Alternatively, you can pass a pyarrow field or schema to add
new columns with NULLs.
""" """
if isinstance(transforms, pa.Field):
transforms = [transforms]
if isinstance(transforms, list) and all(
{isinstance(f, pa.Field) for f in transforms}
):
transforms = pa.schema(transforms)
if isinstance(transforms, pa.Schema):
await self._inner.add_columns_with_schema(transforms)
else:
await self._inner.add_columns(list(transforms.items())) await self._inner.add_columns(list(transforms.items()))
async def alter_columns(self, *alterations: Iterable[dict[str, Any]]): async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
@@ -3783,7 +3610,7 @@ class AsyncTable:
""" """
await self._inner.checkout_latest() await self._inner.checkout_latest()
async def restore(self, version: Optional[int] = None): async def restore(self):
""" """
Restore the table to the currently checked out version Restore the table to the currently checked out version
@@ -3796,7 +3623,7 @@ class AsyncTable:
Once the operation concludes the table will no longer be in a checked Once the operation concludes the table will no longer be in a checked
out state and the read_consistency_interval, if any, will apply. out state and the read_consistency_interval, if any, will apply.
""" """
await self._inner.restore(version) await self._inner.restore()
async def optimize( async def optimize(
self, self,

View File

@@ -253,14 +253,9 @@ def infer_vector_column_name(
query: Optional[Any], # inferred later in query builder query: Optional[Any], # inferred later in query builder
vector_column_name: Optional[str], vector_column_name: Optional[str],
): ):
if vector_column_name is not None: if (vector_column_name is None and query is not None and query_type != "fts") or (
return vector_column_name vector_column_name is None and query_type == "hybrid"
):
if query_type == "fts":
# FTS queries do not require a vector column
return None
if query is not None or query_type == "hybrid":
try: try:
vector_column_name = inf_vector_column_query(schema) vector_column_name = inf_vector_column_query(schema)
except Exception as e: except Exception as e:

View File

@@ -562,7 +562,7 @@ async def test_table_async():
async_db = await lancedb.connect_async(uri, read_consistency_interval=timedelta(0)) async_db = await lancedb.connect_async(uri, read_consistency_interval=timedelta(0))
async_tbl = await async_db.open_table("test_table_async") async_tbl = await async_db.open_table("test_table_async")
# --8<-- [end:table_async_strong_consistency] # --8<-- [end:table_async_strong_consistency]
# --8<-- [start:table_async_eventual_consistency] # --8<-- [start:table_async_ventual_consistency]
uri = "data/sample-lancedb" uri = "data/sample-lancedb"
async_db = await lancedb.connect_async( async_db = await lancedb.connect_async(
uri, read_consistency_interval=timedelta(seconds=5) uri, read_consistency_interval=timedelta(seconds=5)

View File

@@ -6,9 +6,7 @@ import lancedb
# --8<-- [end:import-lancedb] # --8<-- [end:import-lancedb]
# --8<-- [start:import-numpy] # --8<-- [start:import-numpy]
from lancedb.query import BoostQuery, MatchQuery
import numpy as np import numpy as np
import pyarrow as pa
# --8<-- [end:import-numpy] # --8<-- [end:import-numpy]
# --8<-- [start:import-datetime] # --8<-- [start:import-datetime]
@@ -156,84 +154,6 @@ async def test_vector_search_async():
# --8<-- [end:search_result_async_as_list] # --8<-- [end:search_result_async_as_list]
def test_fts_fuzzy_query():
uri = "data/fuzzy-example"
db = lancedb.connect(uri)
table = db.create_table(
"my_table_fts_fuzzy",
data=pa.table(
{
"text": [
"fa",
"fo", # spellchecker:disable-line
"fob",
"focus",
"foo",
"food",
"foul",
]
}
),
mode="overwrite",
)
table.create_fts_index("text", use_tantivy=False, replace=True)
results = table.search(MatchQuery("foo", "text", fuzziness=1)).to_pandas()
assert len(results) == 4
assert set(results["text"].to_list()) == {
"foo",
"fo", # 1 deletion # spellchecker:disable-line
"fob", # 1 substitution
"food", # 1 insertion
}
def test_fts_boost_query():
uri = "data/boost-example"
db = lancedb.connect(uri)
table = db.create_table(
"my_table_fts_boost",
data=pa.table(
{
"title": [
"The Hidden Gems of Travel",
"Exploring Nature's Wonders",
"Cultural Treasures Unveiled",
"The Nightlife Chronicles",
"Scenic Escapes and Challenges",
],
"desc": [
"A vibrant city with occasional traffic jams.",
"Beautiful landscapes but overpriced tourist spots.",
"Rich cultural heritage but humid summers.",
"Bustling nightlife but noisy streets.",
"Scenic views but limited public transport options.",
],
}
),
mode="overwrite",
)
table.create_fts_index("desc", use_tantivy=False, replace=True)
results = table.search(
BoostQuery(
MatchQuery("beautiful, cultural, nightlife", "desc"),
MatchQuery("bad traffic jams, overpriced", "desc"),
),
).to_pandas()
# we will hit 3 results because the positive query has 3 hits
assert len(results) == 3
# the one containing "overpriced" will be negatively boosted,
# so it will be the last one
assert (
results["desc"].to_list()[2]
== "Beautiful landscapes but overpriced tourist spots."
)
def test_fts_native(): def test_fts_native():
# --8<-- [start:basic_fts] # --8<-- [start:basic_fts]
uri = "data/sample-lancedb" uri = "data/sample-lancedb"

View File

@@ -11,8 +11,7 @@ import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector, MultiVector from lancedb.pydantic import LanceModel, Vector
import requests
# These are integration tests for embedding functions. # These are integration tests for embedding functions.
# They are slow because they require downloading models # They are slow because they require downloading models
@@ -517,125 +516,3 @@ def test_voyageai_embedding_function():
tbl.add(df) tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims() assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_function():
voyageai = (
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
)
class Images(LanceModel):
label: str
image_uri: str = voyageai.SourceField() # image uri as the source
image_bytes: bytes = voyageai.SourceField() # image bytes as the source
vector: Vector(voyageai.ndims()) = voyageai.VectorField() # vector column
vec_from_bytes: Vector(voyageai.ndims()) = (
voyageai.VectorField()
) # Another vector column
db = lancedb.connect("~/lancedb")
table = db.create_table("test", schema=Images, mode="overwrite")
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_text_function():
voyageai = (
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
def test_colpali(tmp_path):
import requests
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("colpali").create()
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = (
func.VectorField()
) # Multivector image embeddings
table = db.create_table("media", schema=MediaItems)
texts = [
"a cute cat playing with yarn",
"a puppy in a flower field",
"a red sports car on the highway",
"a vintage bicycle leaning against a wall",
"a plate of delicious pasta",
"fresh fruit salad in a bowl",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# Get images as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
# Test text-to-image search
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
# Verify multivector dimensions
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)

View File

@@ -20,9 +20,7 @@ from unittest import mock
import lancedb as ldb import lancedb as ldb
from lancedb.db import DBConnection from lancedb.db import DBConnection
from lancedb.index import FTS from lancedb.index import FTS
from lancedb.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery
import numpy as np import numpy as np
import pyarrow as pa
import pandas as pd import pandas as pd
import pytest import pytest
from utils import exception_output from utils import exception_output
@@ -180,47 +178,11 @@ def test_search_fts(table, use_tantivy):
results = table.search("puppy").select(["id", "text"]).to_list() results = table.search("puppy").select(["id", "text"]).to_list()
assert len(results) == 10 assert len(results) == 10
if not use_tantivy:
# Test with a query
results = (
table.search(MatchQuery("puppy", "text"))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
# Test boost query
results = (
table.search(
BoostQuery(
MatchQuery("puppy", "text"),
MatchQuery("runs", "text"),
)
)
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
# Test multi match query
table.create_fts_index("text2", use_tantivy=use_tantivy)
results = (
table.search(MultiMatchQuery("puppy", ["text", "text2"]))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fts_select_async(async_table): async def test_fts_select_async(async_table):
tbl = await async_table tbl = await async_table
await tbl.create_index("text", config=FTS()) await tbl.create_index("text", config=FTS())
await tbl.create_index("text2", config=FTS())
results = ( results = (
await tbl.query() await tbl.query()
.nearest_to_text("puppy") .nearest_to_text("puppy")
@@ -231,54 +193,6 @@ async def test_fts_select_async(async_table):
assert len(results) == 5 assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score assert len(results[0]) == 3 # id, text, _score
# Test with FullTextQuery
results = (
await tbl.query()
.nearest_to_text(MatchQuery("puppy", "text"))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
# Test with BoostQuery
results = (
await tbl.query()
.nearest_to_text(
BoostQuery(
MatchQuery("puppy", "text"),
MatchQuery("runs", "text"),
)
)
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
# Test with MultiMatchQuery
results = (
await tbl.query()
.nearest_to_text(MultiMatchQuery("puppy", ["text", "text2"]))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
# Test with search() API
results = (
await (await tbl.search(MatchQuery("puppy", "text")))
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
def test_search_fts_phrase_query(table): def test_search_fts_phrase_query(table):
table.create_fts_index("text", use_tantivy=False, with_position=False) table.create_fts_index("text", use_tantivy=False, with_position=False)
@@ -293,13 +207,6 @@ def test_search_fts_phrase_query(table):
assert len(results) > len(phrase_results) assert len(results) > len(phrase_results)
assert len(phrase_results) > 0 assert len(phrase_results) > 0
# Test with a query
phrase_results = (
table.search(PhraseQuery("puppy runs", "text")).limit(100).to_list()
)
assert len(results) > len(phrase_results)
assert len(phrase_results) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_fts_phrase_query_async(async_table): async def test_search_fts_phrase_query_async(async_table):
@@ -320,16 +227,6 @@ async def test_search_fts_phrase_query_async(async_table):
assert len(results) > len(phrase_results) assert len(results) > len(phrase_results)
assert len(phrase_results) > 0 assert len(phrase_results) > 0
# Test with a query
phrase_results = (
await async_table.query()
.nearest_to_text(PhraseQuery("puppy runs", "text"))
.limit(100)
.to_list()
)
assert len(results) > len(phrase_results)
assert len(phrase_results) > 0
def test_search_fts_specify_column(table): def test_search_fts_specify_column(table):
table.create_fts_index("text", use_tantivy=False) table.create_fts_index("text", use_tantivy=False)
@@ -627,32 +524,3 @@ def test_language(mem_db: DBConnection):
# Stop words -> no results # Stop words -> no results
results = table.search("la", query_type="fts").limit(5).to_list() results = table.search("la", query_type="fts").limit(5).to_list()
assert len(results) == 0 assert len(results) == 0
def test_fts_on_list(mem_db: DBConnection):
data = pa.table(
{
"text": [
["lance database", "the", "search"],
["lance database"],
["lance", "search"],
["database", "search"],
["unrelated", "doc"],
],
"vector": [
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0],
[13.0, 14.0, 15.0],
],
}
)
table = mem_db.create_table("test", data=data)
table.create_fts_index("text", use_tantivy=False)
res = table.search("lance").limit(5).to_list()
assert len(res) == 3
res = table.search(PhraseQuery("lance database", "text")).limit(5).to_list()
assert len(res) == 2

View File

@@ -4,32 +4,13 @@
import lancedb import lancedb
from lancedb.query import LanceHybridQueryBuilder from lancedb.query import LanceHybridQueryBuilder
from lancedb.rerankers.rrf import RRFReranker
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lancedb.index import FTS from lancedb.index import FTS
from lancedb.table import AsyncTable, Table from lancedb.table import AsyncTable
@pytest.fixture
def sync_table(tmpdir_factory) -> Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": pa.array(["a", "b", "cat", "dog"]),
"vector": pa.array(
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", with_position=False, use_tantivy=False)
return table
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -121,42 +102,6 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
assert texts.count("a") == 1 assert texts.count("a") == 1
def test_hybrid_query_distance_range(sync_table: Table):
reranker = RRFReranker(return_score="all")
result = (
sync_table.search(query_type="hybrid")
.vector([0.0, 0.4])
.text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
print(result)
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_hybrid_query_distance_range_async(table: AsyncTable):
reranker = RRFReranker(return_score="all")
result = await (
table.query()
.nearest_to([0.0, 0.4])
.nearest_to_text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_explain_plan(table: AsyncTable): async def test_explain_plan(table: AsyncTable):
plan = await ( plan = await (
@@ -169,16 +114,6 @@ async def test_explain_plan(table: AsyncTable):
assert "LanceScan" in plan assert "LanceScan" in plan
@pytest.mark.asyncio
async def test_analyze_plan(table: AsyncTable):
res = await (
table.query().nearest_to_text("dog").nearest_to([0.1, 0.1]).analyze_plan()
)
assert "AnalyzeExec" in res
assert "metrics=" in res
def test_normalize_scores(): def test_normalize_scores():
cases = [ cases = [
(pa.array([0.1, 0.4]), pa.array([0.0, 1.0])), (pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -31,7 +31,6 @@ async def some_table(db_async):
{ {
"id": list(range(NROWS)), "id": list(range(NROWS)),
"vector": sample_fixed_size_list_array(NROWS, DIM), "vector": sample_fixed_size_list_array(NROWS, DIM),
"fsb": pa.array([bytes([i]) for i in range(NROWS)], pa.binary(1)),
"tags": [ "tags": [
[f"tag{random.randint(0, 8)}" for _ in range(2)] for _ in range(NROWS) [f"tag{random.randint(0, 8)}" for _ in range(2)] for _ in range(NROWS)
], ],
@@ -86,16 +85,6 @@ async def test_create_scalar_index(some_table: AsyncTable):
assert len(indices) == 0 assert len(indices) == 0
@pytest.mark.asyncio
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
await some_table.create_index("fsb", config=BTree())
indices = await some_table.list_indices()
assert str(indices) == '[Index(BTree, columns=["fsb"], name="fsb_idx")]'
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["fsb"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_bitmap_index(some_table: AsyncTable): async def test_create_bitmap_index(some_table: AsyncTable):
await some_table.create_index("id", config=Bitmap()) await some_table.create_index("id", config=Bitmap())
@@ -119,18 +108,6 @@ async def test_create_label_list_index(some_table: AsyncTable):
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]' assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
@pytest.mark.asyncio
async def test_full_text_search_index(some_table: AsyncTable):
await some_table.create_index("tags", config=FTS(with_position=False))
indices = await some_table.list_indices()
assert str(indices) == '[Index(FTS, columns=["tags"], name="tags_idx")]'
await some_table.prewarm_index("tags_idx")
res = await (await some_table.search("tag0")).to_arrow()
assert res.num_rows > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_vector_index(some_table: AsyncTable): async def test_create_vector_index(some_table: AsyncTable):
# Can create # Can create

View File

@@ -9,13 +9,7 @@ from typing import List, Optional, Tuple
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
from lancedb.pydantic import ( from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
PYDANTIC_VERSION,
LanceModel,
Vector,
pydantic_to_schema,
MultiVector,
)
from pydantic import BaseModel from pydantic import BaseModel
from pydantic import Field from pydantic import Field
@@ -360,55 +354,3 @@ def test_optional_nested_model():
), ),
] ]
) )
def test_multi_vector():
class TestModel(pydantic.BaseModel):
vec: MultiVector(8)
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 7])
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 9])
TestModel(vec=[[1.0] * 8])
TestModel(vec=[[1.0] * 8, [2.0] * 8])
TestModel(vec=[])
def test_multi_vector_nullable():
class NullableModel(pydantic.BaseModel):
vec: MultiVector(16, nullable=False)
schema = pydantic_to_schema(NullableModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
)
class DefaultModel(pydantic.BaseModel):
vec: MultiVector(16)
schema = pydantic_to_schema(DefaultModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
)
def test_multi_vector_in_lance_model():
class TestModel(LanceModel):
id: int
vectors: MultiVector(16) = Field(default=[[0.0] * 16])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["id", "vectors"]
t = TestModel(id=1)
assert t.vectors == [[0.0] * 16]

View File

@@ -257,9 +257,7 @@ async def test_distance_range_with_new_rows_async():
} }
) )
table = await conn.create_table("test", data) table = await conn.create_table("test", data)
await table.create_index( table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
)
q = [0, 0] q = [0, 0]
rs = await table.query().nearest_to(q).to_arrow() rs = await table.query().nearest_to(q).to_arrow()
@@ -513,8 +511,7 @@ def test_query_builder_with_different_vector_column():
columns=["b"], columns=["b"],
vector_column="foo_vector", vector_column="foo_vector",
), ),
batch_size=None, None,
timeout=None,
) )
@@ -705,20 +702,6 @@ async def test_fast_search_async(tmp_path):
assert "LanceScan" not in plan assert "LanceScan" not in plan
def test_analyze_plan(table):
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
res = q.analyze_plan()
assert "AnalyzeExec" in res
assert "metrics=" in res
@pytest.mark.asyncio
async def test_analyze_plan_async(table_async: AsyncTable):
res = await table_async.query().nearest_to(pa.array([1, 2])).analyze_plan()
assert "AnalyzeExec" in res
assert "metrics=" in res
def test_explain_plan(table): def test_explain_plan(table):
q = LanceVectorQueryBuilder(table, [0, 0], "vector") q = LanceVectorQueryBuilder(table, [0, 0], "vector")
plan = q.explain_plan(verbose=True) plan = q.explain_plan(verbose=True)
@@ -1079,67 +1062,3 @@ async def test_query_serialization_async(table_async: AsyncTable):
full_text_query=FullTextSearchQuery(columns=[], query="foo"), full_text_query=FullTextSearchQuery(columns=[], query="foo"),
with_row_id=False, with_row_id=False,
) )
def test_query_timeout(tmp_path):
# Use local directory instead of memory:// to add a bit of latency to
# operations so a timeout of zero will trigger exceptions.
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", use_tantivy=False)
with pytest.raises(Exception, match="Query timeout"):
table.search().where("text = 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search("a", query_type="fts").to_pandas(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search(query_type="hybrid").vector([0.0, 0.0]).text("a").to_arrow(
timeout=timedelta(0)
)
@pytest.mark.asyncio
async def test_query_timeout_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = await db.create_table("test", data)
await table.create_index("text", config=FTS())
with pytest.raises(Exception, match="Query timeout"):
await table.query().where("text != 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await table.vector_search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await (await table.search("a", query_type="fts")).to_pandas(
timeout=timedelta(0)
)
with pytest.raises(Exception, match="Query timeout"):
await (
table.query()
.nearest_to_text("a")
.nearest_to([0.0, 0.0])
.to_list(timeout=timedelta(0))
)

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import contextlib import contextlib
from datetime import timedelta from datetime import timedelta
@@ -235,10 +235,6 @@ def test_table_add_in_threadpool():
def test_table_create_indices(): def test_table_create_indices():
def handler(request): def handler(request):
index_stats = dict(
index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
)
if request.path == "/v1/table/test/create_index/": if request.path == "/v1/table/test/create_index/":
request.send_response(200) request.send_response(200)
request.end_headers() request.end_headers()
@@ -262,47 +258,6 @@ def test_table_create_indices():
) )
) )
request.wfile.write(payload.encode()) request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
{
"index_name": "text_idx",
"columns": ["text"],
},
{
"index_name": "vector_idx",
"columns": ["vector"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/text_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/vector_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif "/drop/" in request.path: elif "/drop/" in request.path:
request.send_response(200) request.send_response(200)
request.end_headers() request.end_headers()
@@ -314,81 +269,14 @@ def test_table_create_indices():
# Parameters are well-tested through local and async tests. # Parameters are well-tested through local and async tests.
# This is a smoke-test. # This is a smoke-test.
table = db.create_table("test", [{"id": 1}]) table = db.create_table("test", [{"id": 1}])
table.create_scalar_index("id", wait_timeout=timedelta(seconds=2)) table.create_scalar_index("id")
table.create_fts_index("text", wait_timeout=timedelta(seconds=2)) table.create_fts_index("text")
table.create_index( table.create_scalar_index("vector")
vector_column_name="vector", wait_timeout=timedelta(seconds=10)
)
table.wait_for_index(["id_idx"], timedelta(seconds=2))
table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
table.drop_index("vector_idx") table.drop_index("vector_idx")
table.drop_index("id_idx") table.drop_index("id_idx")
table.drop_index("text_idx") table.drop_index("text_idx")
def test_table_wait_for_index_timeout():
def handler(request):
index_stats = dict(
index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
)
if request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
]
),
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
print(f"{index_stats=}")
request.wfile.write(payload.encode())
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with pytest.raises(
RuntimeError,
match=re.escape(
'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
),
):
table.wait_for_index(["id_idx"], timedelta(seconds=1))
@contextlib.contextmanager @contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")): def query_test_table(query_handler, *, server_version=Version("0.1.0")):
def handler(request): def handler(request):
@@ -556,16 +444,6 @@ def test_query_sync_fts():
"prefilter": True, "prefilter": True,
"with_row_id": True, "with_row_id": True,
"version": None, "version": None,
} or body == {
"full_text_query": {
"query": "puppy",
"columns": ["description", "name"],
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
} }
return pa.table({"id": [1, 2, 3]}) return pa.table({"id": [1, 2, 3]})

View File

@@ -457,45 +457,3 @@ def test_voyageai_reranker(tmp_path, use_tantivy):
reranker = VoyageAIReranker(model_name="rerank-2") reranker = VoyageAIReranker(model_name="rerank-2")
table, schema = get_test_table(tmp_path, use_tantivy) table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema) _run_test_reranker(reranker, table, "single player experience", None, schema)
def test_empty_result_reranker():
pytest.importorskip("sentence_transformers")
db = lancedb.connect("memory://")
# Define schema
schema = pa.schema(
[
("id", pa.int64()),
("text", pa.string()),
("vector", pa.list_(pa.float32(), 128)), # 128-dimensional vector
]
)
# Create empty table with schema
empty_table = db.create_table("empty_table", schema=schema, mode="overwrite")
empty_table.create_fts_index("text", use_tantivy=False, replace=True)
for reranker in [
CrossEncoderReranker(),
# ColbertReranker(),
# AnswerdotaiRerankers(),
# OpenaiReranker(),
# JinaReranker(),
# VoyageAIReranker(model_name="rerank-2"),
]:
results = (
empty_table.search(list(range(128)))
.limit(3)
.rerank(reranker, "query")
.to_arrow()
)
# check if empty set contains _relevance_score column
assert "_relevance_score" in results.column_names
assert len(results) == 0
results = (
empty_table.search("query", query_type="fts")
.limit(3)
.rerank(reranker)
.to_arrow()
)

View File

@@ -9,9 +9,9 @@ from typing import List
from unittest.mock import patch from unittest.mock import patch
import lancedb import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import HnswPq, HnswSq, IvfPq from lancedb.index import HnswPq, HnswSq, IvfPq
import numpy as np import numpy as np
import pandas as pd
import polars as pl import polars as pl
import pyarrow as pa import pyarrow as pa
import pyarrow.dataset import pyarrow.dataset
@@ -138,16 +138,13 @@ def test_create_table(mem_db: DBConnection):
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
] ]
pa_table = pa.Table.from_pylist(rows, schema=schema) df = pd.DataFrame(rows)
pa_table = pa.Table.from_pandas(df, schema=schema)
data = [ data = [
("Rows", rows), ("Rows", rows),
("pd_DataFrame", df),
("pa_Table", pa_table), ("pa_Table", pa_table),
] ]
if _PANDAS_AVAILABLE:
import pandas as pd
df = pd.DataFrame(rows)
data.append(("pd_DataFrame", df))
for name, d in data: for name, d in data:
tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow() tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
@@ -299,7 +296,7 @@ def test_add_subschema(mem_db: DBConnection):
data = {"price": 10.0, "item": "foo"} data = {"price": 10.0, "item": "foo"}
table.add([data]) table.add([data])
data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]}) data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
table.add(data) table.add(data)
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"} data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
table.add([data]) table.add([data])
@@ -408,7 +405,6 @@ def test_add_nullability(mem_db: DBConnection):
def test_add_pydantic_model(mem_db: DBConnection): def test_add_pydantic_model(mem_db: DBConnection):
pytest.importorskip("pandas")
# https://github.com/lancedb/lancedb/issues/562 # https://github.com/lancedb/lancedb/issues/562
class Metadata(BaseModel): class Metadata(BaseModel):
@@ -477,10 +473,10 @@ def test_polars(mem_db: DBConnection):
table = mem_db.create_table("test", data=pl.DataFrame(data)) table = mem_db.create_table("test", data=pl.DataFrame(data))
assert len(table) == 2 assert len(table) == 2
result = table.to_arrow() result = table.to_pandas()
assert np.allclose(result["vector"].to_pylist(), data["vector"]) assert np.allclose(result["vector"].tolist(), data["vector"])
assert result["item"].to_pylist() == data["item"] assert result["item"].tolist() == data["item"]
assert np.allclose(result["price"].to_pylist(), data["price"]) assert np.allclose(result["price"].tolist(), data["price"])
schema = pa.schema( schema = pa.schema(
[ [
@@ -692,7 +688,7 @@ def test_delete(mem_db: DBConnection):
assert len(table.list_versions()) == 2 assert len(table.list_versions()) == 2
assert table.version == 2 assert table.version == 2
assert len(table) == 1 assert len(table) == 1
assert table.to_arrow()["id"].to_pylist() == [1] assert table.to_pandas()["id"].tolist() == [1]
def test_update(mem_db: DBConnection): def test_update(mem_db: DBConnection):
@@ -856,7 +852,6 @@ def test_merge_insert(mem_db: DBConnection):
ids=["pa.Table", "pd.DataFrame", "rows"], ids=["pa.Table", "pd.DataFrame", "rows"],
) )
def test_merge_insert_subschema(mem_db: DBConnection, data_format): def test_merge_insert_subschema(mem_db: DBConnection, data_format):
pytest.importorskip("pandas")
initial_data = pa.table( initial_data = pa.table(
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]} {"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
) )
@@ -953,7 +948,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
func = MockTextEmbeddingFunction.create() func = MockTextEmbeddingFunction.create()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)}) df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig( conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func source_column="text", vector_column="vector", function=func
@@ -978,7 +973,7 @@ def test_create_f16_table(mem_db: DBConnection):
text: str text: str
vector: Vector(32, value_type=pa.float16()) vector: Vector(32, value_type=pa.float16())
df = pa.table( df = pd.DataFrame(
{ {
"text": [f"s-{i}" for i in range(512)], "text": [f"s-{i}" for i in range(512)],
"vector": [np.random.randn(32).astype(np.float16) for _ in range(512)], "vector": [np.random.randn(32).astype(np.float16) for _ in range(512)],
@@ -991,7 +986,7 @@ def test_create_f16_table(mem_db: DBConnection):
table.add(df) table.add(df)
table.create_index(num_partitions=2, num_sub_vectors=2) table.create_index(num_partitions=2, num_sub_vectors=2)
query = df["vector"][2].as_py() query = df.vector.iloc[2]
expected = table.search(query).limit(2).to_arrow() expected = table.search(query).limit(2).to_arrow()
assert "s-2" in expected["text"].to_pylist() assert "s-2" in expected["text"].to_pylist()
@@ -1007,7 +1002,7 @@ def test_add_with_embedding_function(mem_db: DBConnection):
table = mem_db.create_table("my_table", schema=MyTable) table = mem_db.create_table("my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pa.table({"text": texts}) df = pd.DataFrame({"text": texts})
table.add(df) table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"] texts = ["the quick brown fox", "jumped over the lazy dog"]
@@ -1038,14 +1033,14 @@ def test_multiple_vector_columns(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"}, {"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"}, {"vector1": v2, "vector2": v1, "text": "bar"},
] ]
df = pa.Table.from_pylist(data) df = pd.DataFrame(data)
table.add(df) table.add(df)
q = np.random.randn(10) q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow() result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow() result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
assert result1["text"][0] != result2["text"][0] assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_create_scalar_index(mem_db: DBConnection): def test_create_scalar_index(mem_db: DBConnection):
@@ -1083,22 +1078,22 @@ def test_empty_query(mem_db: DBConnection):
"my_table", "my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
) )
df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow() df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df["id"][0].as_py() val = df.id.iloc[0]
assert val == 1 assert val == 1
table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)]) table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_arrow() df = table.search().select(["id"]).to_pandas()
assert df.num_rows == 100 assert len(df) == 100
# None is the same as default # None is the same as default
df = table.search().select(["id"]).limit(None).to_arrow() df = table.search().select(["id"]).limit(None).to_pandas()
assert df.num_rows == 100 assert len(df) == 100
# invalid limist is the same as None, wihch is the same as default # invalid limist is the same as None, wihch is the same as default
df = table.search().select(["id"]).limit(-1).to_arrow() df = table.search().select(["id"]).limit(-1).to_pandas()
assert df.num_rows == 100 assert len(df) == 100
# valid limit should work # valid limit should work
df = table.search().select(["id"]).limit(42).to_arrow() df = table.search().select(["id"]).limit(42).to_pandas()
assert df.num_rows == 42 assert len(df) == 42
def test_search_with_schema_inf_single_vector(mem_db: DBConnection): def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
@@ -1117,14 +1112,14 @@ def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
{"vector_col": v1, "text": "foo"}, {"vector_col": v1, "text": "foo"},
{"vector_col": v2, "text": "bar"}, {"vector_col": v2, "text": "bar"},
] ]
df = pa.Table.from_pylist(data) df = pd.DataFrame(data)
table.add(df) table.add(df)
q = np.random.randn(10) q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow() result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
result2 = table.search(q).limit(1).to_arrow() result2 = table.search(q).limit(1).to_pandas()
assert result1["text"][0].as_py() == result2["text"][0].as_py() assert result1["text"].iloc[0] == result2["text"].iloc[0]
def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection): def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
@@ -1144,12 +1139,12 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"}, {"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"}, {"vector1": v2, "vector2": v1, "text": "bar"},
] ]
df = pa.Table.from_pylist(data) df = pd.DataFrame(data)
table.add(df) table.add(df)
q = np.random.randn(10) q = np.random.randn(10)
with pytest.raises(ValueError): with pytest.raises(ValueError):
table.search(q).limit(1).to_arrow() table.search(q).limit(1).to_pandas()
def test_compact_cleanup(tmp_db: DBConnection): def test_compact_cleanup(tmp_db: DBConnection):
@@ -1389,37 +1384,6 @@ async def test_add_columns_async(mem_db_async: AsyncConnection):
assert data["new_col"].to_pylist() == [2, 3] assert data["new_col"].to_pylist() == [2, 3]
@pytest.mark.asyncio
async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
data = pa.table({"id": [0, 1]})
table = await mem_db_async.create_table("my_table", data=data)
await table.add_columns(
[pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))]
)
assert await table.schema() == pa.schema(
[
pa.field("id", pa.int64()),
pa.field("x", pa.int64()),
pa.field("vector", pa.list_(pa.float32(), 8)),
]
)
table = await mem_db_async.create_table("table2", data=data)
await table.add_columns(
pa.schema(
[pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))]
)
)
assert await table.schema() == pa.schema(
[
pa.field("id", pa.int64()),
pa.field("y", pa.int64()),
pa.field("emb", pa.list_(pa.float32(), 8)),
]
)
def test_alter_columns(mem_db: DBConnection): def test_alter_columns(mem_db: DBConnection):
data = pa.table({"id": [0, 1]}) data = pa.table({"id": [0, 1]})
table = mem_db.create_table("my_table", data=data) table = mem_db.create_table("my_table", data=data)

View File

@@ -2,26 +2,25 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use arrow::array::make_array; use arrow::array::make_array;
use arrow::array::Array; use arrow::array::Array;
use arrow::array::ArrayData; use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow; use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow; use arrow::pyarrow::IntoPyArrow;
use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery}; use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::QueryExecutionOptions; use lancedb::query::QueryExecutionOptions;
use lancedb::query::QueryFilter; use lancedb::query::QueryFilter;
use lancedb::query::{ use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
}; };
use lancedb::table::AnyQuery; use lancedb::table::AnyQuery;
use pyo3::exceptions::PyNotImplementedError;
use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::prelude::{PyAnyMethods, PyDictMethods}; use pyo3::prelude::{PyAnyMethods, PyDictMethods};
use pyo3::pymethods; use pyo3::pymethods;
use pyo3::types::PyDict;
use pyo3::types::PyList; use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString};
use pyo3::Bound; use pyo3::Bound;
use pyo3::IntoPyObject; use pyo3::IntoPyObject;
use pyo3::PyAny; use pyo3::PyAny;
@@ -32,7 +31,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
use crate::arrow::RecordBatchStream; use crate::arrow::RecordBatchStream;
use crate::error::PythonErrorExt; use crate::error::PythonErrorExt;
use crate::util::{parse_distance_type, parse_fts_query}; use crate::util::parse_distance_type;
// Python representation of full text search parameters // Python representation of full text search parameters
#[derive(Clone)] #[derive(Clone)]
@@ -46,9 +45,9 @@ pub struct PyFullTextSearchQuery {
impl From<FullTextSearchQuery> for PyFullTextSearchQuery { impl From<FullTextSearchQuery> for PyFullTextSearchQuery {
fn from(query: FullTextSearchQuery) -> Self { fn from(query: FullTextSearchQuery) -> Self {
Self { PyFullTextSearchQuery {
columns: query.columns().into_iter().collect(), columns: query.columns,
query: query.query.query().to_owned(), query: query.query,
limit: query.limit, limit: query.limit,
wand_factor: query.wand_factor, wand_factor: query.wand_factor,
} }
@@ -100,7 +99,7 @@ pub struct PyQueryRequest {
impl From<AnyQuery> for PyQueryRequest { impl From<AnyQuery> for PyQueryRequest {
fn from(query: AnyQuery) -> Self { fn from(query: AnyQuery) -> Self {
match query { match query {
AnyQuery::Query(query_request) => Self { AnyQuery::Query(query_request) => PyQueryRequest {
limit: query_request.limit, limit: query_request.limit,
offset: query_request.offset, offset: query_request.offset,
filter: query_request.filter.map(PyQueryFilter), filter: query_request.filter.map(PyQueryFilter),
@@ -122,7 +121,7 @@ impl From<AnyQuery> for PyQueryRequest {
postfilter: None, postfilter: None,
norm: None, norm: None,
}, },
AnyQuery::VectorQuery(vector_query) => Self { AnyQuery::VectorQuery(vector_query) => PyQueryRequest {
limit: vector_query.base.limit, limit: vector_query.base.limit,
offset: vector_query.base.offset, offset: vector_query.base.offset,
filter: vector_query.base.filter.map(PyQueryFilter), filter: vector_query.base.filter.map(PyQueryFilter),
@@ -237,69 +236,29 @@ impl Query {
} }
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<FTSQuery> { pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<FTSQuery> {
let fts_query = query let query_text = query
.get_item("query")? .get_item("query")?
.ok_or(PyErr::new::<PyRuntimeError, _>( .ok_or(PyErr::new::<PyRuntimeError, _>(
"Query text is required for nearest_to_text", "Query text is required for nearest_to_text",
))?; ))?
.extract::<String>()?;
let query = if let Ok(query_text) = fts_query.downcast::<PyString>() {
let mut query_text = query_text.to_string();
let columns = query let columns = query
.get_item("columns")? .get_item("columns")?
.map(|columns| columns.extract::<Vec<String>>()) .map(|columns| columns.extract::<Vec<String>>())
.transpose()?; .transpose()?;
let is_phrase = let fts_query = FullTextSearchQuery::new(query_text).columns(columns);
query_text.len() >= 2 && query_text.starts_with('"') && query_text.ends_with('"');
let is_multi_match = columns.as_ref().map(|cols| cols.len() > 1).unwrap_or(false);
if is_phrase {
// Remove the surrounding quotes for phrase queries
query_text = query_text[1..query_text.len() - 1].to_string();
}
let query: FtsQuery = match (is_phrase, is_multi_match) {
(false, _) => MatchQuery::new(query_text).into(),
(true, false) => PhraseQuery::new(query_text).into(),
(true, true) => {
return Err(PyValueError::new_err(
"Phrase queries cannot be used with multiple columns.",
));
}
};
let mut query = FullTextSearchQuery::new_query(query);
if let Some(cols) = columns {
if !cols.is_empty() {
query = query.with_columns(&cols).map_err(|e| {
PyValueError::new_err(format!(
"Failed to set full text search columns: {}",
e
))
})?;
}
}
query
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
let query = parse_fts_query(query)?;
FullTextSearchQuery::new_query(query)
} else {
return Err(PyValueError::new_err(
"query must be a string or a Query object",
));
};
Ok(FTSQuery { Ok(FTSQuery {
fts_query,
inner: self.inner.clone(), inner: self.inner.clone(),
fts_query: query,
}) })
} }
#[pyo3(signature = (max_batch_length=None, timeout=None))] #[pyo3(signature = (max_batch_length=None))]
pub fn execute( pub fn execute(
self_: PyRef<'_, Self>, self_: PyRef<'_, Self>,
max_batch_length: Option<u32>, max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> { ) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
@@ -307,15 +266,12 @@ impl Query {
if let Some(max_batch_length) = max_batch_length { if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length; opts.max_batch_length = max_batch_length;
} }
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?; let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream)) Ok(RecordBatchStream::new(inner_stream))
}) })
} }
pub fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<Bound<'_, PyAny>> { fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
inner inner
@@ -325,16 +281,6 @@ impl Query {
}) })
} }
pub fn analyze_plan(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
inner
.analyze_plan()
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
})
}
pub fn to_query_request(&self) -> PyQueryRequest { pub fn to_query_request(&self) -> PyQueryRequest {
PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request())) PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request()))
} }
@@ -381,11 +327,10 @@ impl FTSQuery {
self.inner = self.inner.clone().postfilter(); self.inner = self.inner.clone().postfilter();
} }
#[pyo3(signature = (max_batch_length=None, timeout=None))] #[pyo3(signature = (max_batch_length=None))]
pub fn execute( pub fn execute(
self_: PyRef<'_, Self>, self_: PyRef<'_, Self>,
max_batch_length: Option<u32>, max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> { ) -> PyResult<Bound<'_, PyAny>> {
let inner = self_ let inner = self_
.inner .inner
@@ -397,9 +342,6 @@ impl FTSQuery {
if let Some(max_batch_length) = max_batch_length { if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length; opts.max_batch_length = max_batch_length;
} }
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?; let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream)) Ok(RecordBatchStream::new(inner_stream))
}) })
@@ -423,18 +365,8 @@ impl FTSQuery {
}) })
} }
pub fn analyze_plan(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
inner
.analyze_plan()
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
})
}
pub fn get_query(&self) -> String { pub fn get_query(&self) -> String {
self.fts_query.query.query().to_owned() self.fts_query.query.clone()
} }
pub fn to_query_request(&self) -> PyQueryRequest { pub fn to_query_request(&self) -> PyQueryRequest {
@@ -522,11 +454,10 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index() self.inner = self.inner.clone().bypass_vector_index()
} }
#[pyo3(signature = (max_batch_length=None, timeout=None))] #[pyo3(signature = (max_batch_length=None))]
pub fn execute( pub fn execute(
self_: PyRef<'_, Self>, self_: PyRef<'_, Self>,
max_batch_length: Option<u32>, max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> { ) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
@@ -534,15 +465,12 @@ impl VectorQuery {
if let Some(max_batch_length) = max_batch_length { if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length; opts.max_batch_length = max_batch_length;
} }
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?; let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream)) Ok(RecordBatchStream::new(inner_stream))
}) })
} }
pub fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<Bound<'_, PyAny>> { fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
inner inner
@@ -552,16 +480,6 @@ impl VectorQuery {
}) })
} }
pub fn analyze_plan(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
inner
.analyze_plan()
.await
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
})
}
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> { pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
let base_query = self.inner.clone().into_plain(); let base_query = self.inner.clone().into_plain();
let fts_query = Query::new(base_query).nearest_to_text(query)?; let fts_query = Query::new(base_query).nearest_to_text(query)?;
@@ -652,11 +570,6 @@ impl HybridQuery {
self.inner_vec.bypass_vector_index(); self.inner_vec.bypass_vector_index();
} }
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner_vec.distance_range(lower_bound, upper_bound);
}
pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> { pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
Ok(VectorQuery { Ok(VectorQuery {
inner: self.inner_vec.inner.clone(), inner: self.inner_vec.inner.clone(),

View File

@@ -1,11 +1,9 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{collections::HashMap, sync::Arc};
use arrow::{ use arrow::{
datatypes::{DataType, Schema}, datatypes::DataType,
ffi_stream::ArrowArrayStreamReader, ffi_stream::ArrowArrayStreamReader,
pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}, pyarrow::{FromPyArrow, ToPyArrow},
}; };
use lancedb::table::{ use lancedb::table::{
AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions, AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions,
@@ -18,6 +16,7 @@ use pyo3::{
Bound, FromPyObject, PyAny, PyRef, PyResult, Python, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
}; };
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
use std::collections::HashMap;
use crate::{ use crate::{
error::PythonErrorExt, error::PythonErrorExt,
@@ -177,19 +176,15 @@ impl Table {
}) })
} }
#[pyo3(signature = (column, index=None, replace=None, wait_timeout=None))] #[pyo3(signature = (column, index=None, replace=None))]
pub fn create_index<'a>( pub fn create_index<'a>(
self_: PyRef<'a, Self>, self_: PyRef<'a, Self>,
column: String, column: String,
index: Option<Bound<'_, PyAny>>, index: Option<Bound<'_, PyAny>>,
replace: Option<bool>, replace: Option<bool>,
wait_timeout: Option<Bound<'_, PyAny>>,
) -> PyResult<Bound<'a, PyAny>> { ) -> PyResult<Bound<'a, PyAny>> {
let index = extract_index_params(&index)?; let index = extract_index_params(&index)?;
let timeout = wait_timeout.map(|t| t.extract::<std::time::Duration>().unwrap()); let mut op = self_.inner_ref()?.create_index(&[column], index);
let mut op = self_
.inner_ref()?
.create_index_with_timeout(&[column], index, timeout);
if let Some(replace) = replace { if let Some(replace) = replace {
op = op.replace(replace); op = op.replace(replace);
} }
@@ -208,34 +203,6 @@ impl Table {
}) })
} }
pub fn wait_for_index<'a>(
self_: PyRef<'a, Self>,
index_names: Vec<String>,
timeout: Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.inner_ref()?.clone();
let timeout = timeout.extract::<std::time::Duration>()?;
future_into_py(self_.py(), async move {
let index_refs = index_names
.iter()
.map(String::as_str)
.collect::<Vec<&str>>();
inner
.wait_for_index(&index_refs, timeout)
.await
.infer_error()?;
Ok(())
})
}
pub fn prewarm_index(self_: PyRef<'_, Self>, index_name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.prewarm_index(&index_name).await.infer_error()?;
Ok(())
})
}
pub fn list_indices(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> { pub fn list_indices(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone(); let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
@@ -336,16 +303,12 @@ impl Table {
}) })
} }
#[pyo3(signature = (version=None))] pub fn restore(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
pub fn restore(self_: PyRef<'_, Self>, version: Option<u64>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone(); let inner = self_.inner_ref()?.clone();
future_into_py(
future_into_py(self_.py(), async move { self_.py(),
if let Some(version) = version { async move { inner.restore().await.infer_error() },
inner.checkout(version).await.infer_error()?; )
}
inner.restore().await.infer_error()
})
} }
pub fn query(&self) -> Query { pub fn query(&self) -> Query {
@@ -477,20 +440,6 @@ impl Table {
}) })
} }
pub fn add_columns_with_schema(
self_: PyRef<'_, Self>,
schema: PyArrowType<Schema>,
) -> PyResult<Bound<'_, PyAny>> {
let arrow_schema = &schema.0;
let transform = NewColumnTransform::AllNulls(Arc::new(arrow_schema.clone()));
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.add_columns(transform, None).await.infer_error()?;
Ok(())
})
}
pub fn alter_columns<'a>( pub fn alter_columns<'a>(
self_: PyRef<'a, Self>, self_: PyRef<'a, Self>,
alterations: Vec<Bound<PyDict>>, alterations: Vec<Bound<PyDict>>,

View File

@@ -3,15 +3,11 @@
use std::sync::Mutex; use std::sync::Mutex;
use lancedb::index::scalar::{BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, PhraseQuery};
use lancedb::DistanceType; use lancedb::DistanceType;
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods};
use pyo3::types::PyDict;
use pyo3::{ use pyo3::{
exceptions::{PyRuntimeError, PyValueError}, exceptions::{PyRuntimeError, PyValueError},
pyfunction, PyResult, pyfunction, PyResult,
}; };
use pyo3::{Bound, PyAny};
/// A wrapper around a rust builder /// A wrapper around a rust builder
/// ///
@@ -63,117 +59,3 @@ pub fn validate_table_name(table_name: &str) -> PyResult<()> {
lancedb::utils::validate_table_name(table_name) lancedb::utils::validate_table_name(table_name)
.map_err(|e| PyValueError::new_err(e.to_string())) .map_err(|e| PyValueError::new_err(e.to_string()))
} }
pub fn parse_fts_query(query: &Bound<'_, PyDict>) -> PyResult<FtsQuery> {
let query_type = query.keys().get_item(0)?.extract::<String>()?;
let query_value = query
.get_item(&query_type)?
.ok_or(PyValueError::new_err(format!(
"Query type {} not found",
query_type
)))?;
let query_value = query_value.downcast::<PyDict>()?;
match query_type.as_str() {
"match" => {
let column = query_value.keys().get_item(0)?.extract::<String>()?;
let params = query_value
.get_item(&column)?
.ok_or(PyValueError::new_err(format!(
"column {} not found",
column
)))?;
let params = params.downcast::<PyDict>()?;
let query = params
.get_item("query")?
.ok_or(PyValueError::new_err("query not found"))?
.extract::<String>()?;
let boost = params
.get_item("boost")?
.ok_or(PyValueError::new_err("boost not found"))?
.extract::<f32>()?;
let fuzziness = params
.get_item("fuzziness")?
.ok_or(PyValueError::new_err("fuzziness not found"))?
.extract::<Option<u32>>()?;
let max_expansions = params
.get_item("max_expansions")?
.ok_or(PyValueError::new_err("max_expansions not found"))?
.extract::<usize>()?;
let query = MatchQuery::new(query)
.with_column(Some(column))
.with_boost(boost)
.with_fuzziness(fuzziness)
.with_max_expansions(max_expansions);
Ok(query.into())
}
"match_phrase" => {
let column = query_value.keys().get_item(0)?.extract::<String>()?;
let query = query_value
.get_item(&column)?
.ok_or(PyValueError::new_err(format!(
"column {} not found",
column
)))?
.extract::<String>()?;
let query = PhraseQuery::new(query).with_column(Some(column));
Ok(query.into())
}
"boost" => {
let positive: Bound<'_, PyAny> = query_value
.get_item("positive")?
.ok_or(PyValueError::new_err("positive not found"))?;
let positive = positive.downcast::<PyDict>()?;
let negative = query_value
.get_item("negative")?
.ok_or(PyValueError::new_err("negative not found"))?;
let negative = negative.downcast::<PyDict>()?;
let negative_boost = query_value
.get_item("negative_boost")?
.ok_or(PyValueError::new_err("negative_boost not found"))?
.extract::<f32>()?;
let positive_query = parse_fts_query(positive)?;
let negative_query = parse_fts_query(negative)?;
let query = BoostQuery::new(positive_query, negative_query, Some(negative_boost));
Ok(query.into())
}
"multi_match" => {
let query = query_value
.get_item("query")?
.ok_or(PyValueError::new_err("query not found"))?
.extract::<String>()?;
let columns = query_value
.get_item("columns")?
.ok_or(PyValueError::new_err("columns not found"))?
.extract::<Vec<String>>()?;
let boost = query_value
.get_item("boost")?
.ok_or(PyValueError::new_err("boost not found"))?
.extract::<Vec<f32>>()?;
let query = MultiMatchQuery::try_new(query, columns)
.and_then(|q| q.try_with_boosts(boost))
.map_err(|e| {
PyValueError::new_err(format!("Error creating MultiMatchQuery: {}", e))
})?;
Ok(query.into())
}
_ => Err(PyValueError::new_err(format!(
"Unsupported query type: {}",
query_type
))),
}
}

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-node" name = "lancedb-node"
version = "0.19.0-beta.11" version = "0.18.2-beta.1"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true
edition.workspace = true edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.19.0-beta.11" version = "0.18.2-beta.1"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true

View File

@@ -81,7 +81,7 @@ impl ListingCatalogOptionsBuilder {
/// [`crate::database::listing::ListingDatabase`] /// [`crate::database::listing::ListingDatabase`]
#[derive(Debug)] #[derive(Debug)]
pub struct ListingCatalog { pub struct ListingCatalog {
object_store: Arc<ObjectStore>, object_store: ObjectStore,
uri: String, uri: String,
@@ -105,7 +105,7 @@ impl ListingCatalog {
} }
async fn open_path(path: &str) -> Result<Self> { async fn open_path(path: &str) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_uri(path).await.unwrap(); let (object_store, base_path) = ObjectStore::from_path(path).unwrap();
if object_store.is_local() { if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?; Self::try_create_dir(path).context(CreateDirSnafu { path })?;
} }

View File

@@ -139,6 +139,12 @@ impl CreateTableBuilder<true> {
} }
} }
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.request.write_options = write_options;
self
}
/// Execute the create table operation /// Execute the create table operation
pub async fn execute(self) -> Result<Table> { pub async fn execute(self) -> Result<Table> {
let embedding_registry = self.embedding_registry.clone(); let embedding_registry = self.embedding_registry.clone();
@@ -220,12 +226,6 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
self self
} }
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.request.write_options = write_options;
self
}
/// Set an option for the storage layer. /// Set an option for the storage layer.
/// ///
/// Options already set on the connection will be inherited by the table, /// Options already set on the connection will be inherited by the table,

View File

@@ -201,7 +201,7 @@ impl ListingDatabaseOptionsBuilder {
/// We will have two tables named `table1` and `table2`. /// We will have two tables named `table1` and `table2`.
#[derive(Debug)] #[derive(Debug)]
pub struct ListingDatabase { pub struct ListingDatabase {
object_store: Arc<ObjectStore>, object_store: ObjectStore,
query_string: Option<String>, query_string: Option<String>,
pub(crate) uri: String, pub(crate) uri: String,

View File

@@ -35,8 +35,6 @@ pub enum Error {
Schema { message: String }, Schema { message: String },
#[snafu(display("Runtime error: {message}"))] #[snafu(display("Runtime error: {message}"))]
Runtime { message: String }, Runtime { message: String },
#[snafu(display("Timeout error: {message}"))]
Timeout { message: String },
// 3rd party / external errors // 3rd party / external errors
#[snafu(display("object_store error: {source}"))] #[snafu(display("object_store error: {source}"))]

View File

@@ -1,11 +1,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use scalar::FtsIndexBuilder; use scalar::FtsIndexBuilder;
use serde::Deserialize; use serde::Deserialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder; use vector::IvfFlatIndexBuilder;
use crate::{table::BaseTable, DistanceType, Error, Result}; use crate::{table::BaseTable, DistanceType, Error, Result};
@@ -17,7 +17,6 @@ use self::{
pub mod scalar; pub mod scalar;
pub mod vector; pub mod vector;
pub mod waiter;
/// Supported index types. /// Supported index types.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -70,7 +69,6 @@ pub struct IndexBuilder {
pub(crate) index: Index, pub(crate) index: Index,
pub(crate) columns: Vec<String>, pub(crate) columns: Vec<String>,
pub(crate) replace: bool, pub(crate) replace: bool,
pub(crate) wait_timeout: Option<Duration>,
} }
impl IndexBuilder { impl IndexBuilder {
@@ -80,7 +78,6 @@ impl IndexBuilder {
index, index,
columns, columns,
replace: true, replace: true,
wait_timeout: None,
} }
} }
@@ -94,15 +91,6 @@ impl IndexBuilder {
self self
} }
/// Duration of time to wait for asynchronous indexing to complete. If not set,
/// `create_index()` will not wait.
///
/// This is not supported for `NativeTable` since indexing is synchronous.
pub fn wait_timeout(mut self, d: Duration) -> Self {
self.wait_timeout = Some(d);
self
}
pub async fn execute(self) -> Result<()> { pub async fn execute(self) -> Result<()> {
self.parent.clone().create_index(self).await self.parent.clone().create_index(self).await
} }

View File

@@ -80,6 +80,5 @@ impl FtsIndexBuilder {
} }
} }
pub use lance_index::scalar::inverted::query::*;
pub use lance_index::scalar::inverted::TokenizerConfig; pub use lance_index::scalar::inverted::TokenizerConfig;
pub use lance_index::scalar::FullTextSearchQuery; pub use lance_index::scalar::FullTextSearchQuery;

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::error::Result;
use crate::table::BaseTable;
use crate::Error;
use log::debug;
use std::time::{Duration, Instant};
use tokio::time::sleep;
const DEFAULT_SLEEP_MS: u64 = 1000;
const MAX_WAIT: Duration = Duration::from_secs(2 * 60 * 60);
/// Poll the table using list_indices() and index_stats() until all of the indices have 0 un-indexed rows.
/// Will return Error::Timeout if the columns are not fully indexed within the timeout.
pub async fn wait_for_index(
table: &dyn BaseTable,
index_names: &[&str],
timeout: Duration,
) -> Result<()> {
if timeout > MAX_WAIT {
return Err(Error::InvalidInput {
message: format!("timeout must be less than {:?}", MAX_WAIT),
});
}
let start = Instant::now();
let mut remaining = index_names.to_vec();
// poll via list_indices() and index_stats() until all indices are created and fully indexed
while start.elapsed() < timeout {
let mut completed = vec![];
let indices = table.list_indices().await?;
for &idx in &remaining {
if !indices.iter().any(|i| i.name == *idx) {
debug!("still waiting for new index '{}'", idx);
continue;
}
let stats = table.index_stats(idx.as_ref()).await?;
match stats {
None => {
debug!("still waiting for new index '{}'", idx);
continue;
}
Some(s) => {
if s.num_unindexed_rows == 0 {
// note: this may never stabilize under constant writes.
// we should later replace this with a status/job model
completed.push(idx);
debug!(
"fully indexed '{}'. indexed rows: {}",
idx, s.num_indexed_rows
);
} else {
debug!(
"still waiting for index '{}'. unindexed rows: {}",
idx, s.num_unindexed_rows
);
}
}
}
}
remaining.retain(|idx| !completed.contains(idx));
if remaining.is_empty() {
return Ok(());
}
sleep(Duration::from_millis(DEFAULT_SLEEP_MS)).await;
}
// debug log index diagnostics
for &r in &remaining {
let stats = table.index_stats(r.as_ref()).await?;
match stats {
Some(s) => debug!(
"index '{}' not fully indexed after {:?}. stats: {:?}",
r, timeout, s
),
None => debug!("index '{}' not found after {:?}", r, timeout),
}
}
Err(Error::Timeout {
message: format!(
"timed out waiting for indices: {:?} after {:?}",
remaining, timeout
),
})
}

View File

@@ -14,9 +14,6 @@ use object_store::{
use async_trait::async_trait; use async_trait::async_trait;
#[cfg(test)]
pub mod io_tracking;
#[derive(Debug)] #[derive(Debug)]
struct MirroringObjectStore { struct MirroringObjectStore {
primary: Arc<dyn ObjectStore>, primary: Arc<dyn ObjectStore>,

View File

@@ -1,237 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
fmt::{Display, Formatter},
sync::{Arc, Mutex},
};
use bytes::Bytes;
use futures::stream::BoxStream;
use lance::io::WrappingObjectStore;
use object_store::{
path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOpts, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
};
#[derive(Debug, Default)]
pub struct IoStats {
pub read_iops: u64,
pub read_bytes: u64,
pub write_iops: u64,
pub write_bytes: u64,
}
impl Display for IoStats {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[derive(Debug, Clone)]
pub struct IoTrackingStore {
target: Arc<dyn ObjectStore>,
stats: Arc<Mutex<IoStats>>,
}
impl Display for IoTrackingStore {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[derive(Debug, Default, Clone)]
pub struct IoStatsHolder(Arc<Mutex<IoStats>>);
impl IoStatsHolder {
pub fn incremental_stats(&self) -> IoStats {
std::mem::take(&mut self.0.lock().expect("failed to lock IoStats"))
}
}
impl WrappingObjectStore for IoStatsHolder {
fn wrap(&self, target: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
Arc::new(IoTrackingStore {
target,
stats: self.0.clone(),
})
}
}
impl IoTrackingStore {
pub fn new_wrapper() -> (Arc<dyn WrappingObjectStore>, Arc<Mutex<IoStats>>) {
let stats = Arc::new(Mutex::new(IoStats::default()));
(Arc::new(IoStatsHolder(stats.clone())), stats)
}
fn record_read(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
stats.read_iops += 1;
stats.read_bytes += num_bytes;
}
fn record_write(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
stats.write_iops += 1;
stats.write_bytes += num_bytes;
}
}
#[async_trait::async_trait]
#[deny(clippy::missing_trait_methods)]
impl ObjectStore for IoTrackingStore {
async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
self.record_write(bytes.content_length() as u64);
self.target.put(location, bytes).await
}
async fn put_opts(
&self,
location: &Path,
bytes: PutPayload,
opts: PutOptions,
) -> OSResult<PutResult> {
self.record_write(bytes.content_length() as u64);
self.target.put_opts(location, bytes, opts).await
}
async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
let target = self.target.put_multipart(location).await?;
Ok(Box::new(IoTrackingMultipartUpload {
target,
stats: self.stats.clone(),
}))
}
async fn put_multipart_opts(
&self,
location: &Path,
opts: PutMultipartOpts,
) -> OSResult<Box<dyn MultipartUpload>> {
let target = self.target.put_multipart_opts(location, opts).await?;
Ok(Box::new(IoTrackingMultipartUpload {
target,
stats: self.stats.clone(),
}))
}
async fn get(&self, location: &Path) -> OSResult<GetResult> {
let result = self.target.get(location).await;
if let Ok(result) = &result {
let num_bytes = result.range.end - result.range.start;
self.record_read(num_bytes as u64);
}
result
}
async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
let result = self.target.get_opts(location, options).await;
if let Ok(result) = &result {
let num_bytes = result.range.end - result.range.start;
self.record_read(num_bytes as u64);
}
result
}
async fn get_range(&self, location: &Path, range: std::ops::Range<usize>) -> OSResult<Bytes> {
let result = self.target.get_range(location, range).await;
if let Ok(result) = &result {
self.record_read(result.len() as u64);
}
result
}
async fn get_ranges(
&self,
location: &Path,
ranges: &[std::ops::Range<usize>],
) -> OSResult<Vec<Bytes>> {
let result = self.target.get_ranges(location, ranges).await;
if let Ok(result) = &result {
self.record_read(result.iter().map(|b| b.len() as u64).sum());
}
result
}
async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
self.record_read(0);
self.target.head(location).await
}
async fn delete(&self, location: &Path) -> OSResult<()> {
self.record_write(0);
self.target.delete(location).await
}
fn delete_stream<'a>(
&'a self,
locations: BoxStream<'a, OSResult<Path>>,
) -> BoxStream<'a, OSResult<Path>> {
self.target.delete_stream(locations)
}
fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, OSResult<ObjectMeta>> {
self.record_read(0);
self.target.list(prefix)
}
fn list_with_offset(
&self,
prefix: Option<&Path>,
offset: &Path,
) -> BoxStream<'_, OSResult<ObjectMeta>> {
self.record_read(0);
self.target.list_with_offset(prefix, offset)
}
async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
self.record_read(0);
self.target.list_with_delimiter(prefix).await
}
async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.copy(from, to).await
}
async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.rename(from, to).await
}
async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.rename_if_not_exists(from, to).await
}
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.copy_if_not_exists(from, to).await
}
}
#[derive(Debug)]
struct IoTrackingMultipartUpload {
target: Box<dyn MultipartUpload>,
stats: Arc<Mutex<IoStats>>,
}
#[async_trait::async_trait]
impl MultipartUpload for IoTrackingMultipartUpload {
async fn abort(&mut self) -> OSResult<()> {
self.target.abort().await
}
async fn complete(&mut self) -> OSResult<PutResult> {
self.target.complete().await
}
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
{
let mut stats = self.stats.lock().unwrap();
stats.write_iops += 1;
stats.write_bytes += payload.content_length() as u64;
}
self.target.put_part(payload)
}
}

View File

@@ -1,8 +1,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches; use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array}; use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
@@ -25,7 +25,6 @@ use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker; use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker}; use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::BaseTable; use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::DistanceType; use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery}; use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
@@ -526,15 +525,12 @@ pub struct QueryExecutionOptions {
/// ///
/// By default, this is 1024 /// By default, this is 1024
pub max_batch_length: u32, pub max_batch_length: u32,
/// Max duration to wait for the query to execute before timing out.
pub timeout: Option<Duration>,
} }
impl Default for QueryExecutionOptions { impl Default for QueryExecutionOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_batch_length: 1024, max_batch_length: 1024,
timeout: None,
} }
} }
} }
@@ -583,15 +579,6 @@ pub trait ExecutableQuery {
) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send; ) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send; fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
self.analyze_plan_with_options(QueryExecutionOptions::default())
}
fn analyze_plan_with_options(
&self,
options: QueryExecutionOptions,
) -> impl Future<Output = Result<String>> + Send;
} }
/// A query filter that can be applied to a query /// A query filter that can be applied to a query
@@ -778,11 +765,6 @@ impl ExecutableQuery for Query {
let query = AnyQuery::Query(self.request.clone()); let query = AnyQuery::Query(self.request.clone());
self.parent.explain_plan(&query, verbose).await self.parent.explain_plan(&query, verbose).await
} }
async fn analyze_plan_with_options(&self, options: QueryExecutionOptions) -> Result<String> {
let query = AnyQuery::Query(self.request.clone());
self.parent.analyze_plan(&query, options).await
}
} }
/// A request for a nearest-neighbors search into a table /// A request for a nearest-neighbors search into a table
@@ -1011,10 +993,7 @@ impl VectorQuery {
self self
} }
pub async fn execute_hybrid( pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
// clone query and specify we want to include row IDs, which can be needed for reranking // clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone()); let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone(); fts_query.request = self.request.base.clone();
@@ -1023,10 +1002,7 @@ impl VectorQuery {
let mut vector_query = self.clone().with_row_id(); let mut vector_query = self.clone().with_row_id();
vector_query.request.base.full_text_search = None; vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!( let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
)?;
let (fts_results, vec_results) = try_join!( let (fts_results, vec_results) = try_join!(
fts_results.try_collect::<Vec<_>>(), fts_results.try_collect::<Vec<_>>(),
@@ -1066,7 +1042,7 @@ impl VectorQuery {
})?; })?;
let mut results = reranker let mut results = reranker
.rerank_hybrid(&fts_query.query.query(), vec_results, fts_results) .rerank_hybrid(&fts_query.query, vec_results, fts_results)
.await?; .await?;
check_reranker_result(&results)?; check_reranker_result(&results)?;
@@ -1084,20 +1060,6 @@ impl VectorQuery {
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])), RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
)) ))
} }
async fn inner_execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner).into())
}
} }
impl ExecutableQuery for VectorQuery { impl ExecutableQuery for VectorQuery {
@@ -1111,24 +1073,22 @@ impl ExecutableQuery for VectorQuery {
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
if self.request.base.full_text_search.is_some() { if self.request.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid(options).await } let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
.boxed()
.await?;
return Ok(hybrid_result); return Ok(hybrid_result);
} }
self.inner_execute_with_options(options).await Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
))
} }
async fn explain_plan(&self, verbose: bool) -> Result<String> { async fn explain_plan(&self, verbose: bool) -> Result<String> {
let query = AnyQuery::VectorQuery(self.request.clone()); let query = AnyQuery::VectorQuery(self.request.clone());
self.parent.explain_plan(&query, verbose).await self.parent.explain_plan(&query, verbose).await
} }
async fn analyze_plan_with_options(&self, options: QueryExecutionOptions) -> Result<String> {
let query = AnyQuery::VectorQuery(self.request.clone());
self.parent.analyze_plan(&query, options).await
}
} }
impl HasQuery for VectorQuery { impl HasQuery for VectorQuery {
@@ -1410,31 +1370,6 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_analyze_plan() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let result = table.query().analyze_plan().await.unwrap();
assert!(result.contains("metrics="));
}
#[tokio::test]
async fn test_analyze_plan_with_options() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let result = table
.query()
.analyze_plan_with_options(QueryExecutionOptions {
max_batch_length: 10,
..Default::default()
})
.await
.unwrap();
assert!(result.contains("metrics="));
}
fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool { fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
if plan.name() == name { if plan.name() == name {
return true; return true;

View File

@@ -8,7 +8,6 @@
pub(crate) mod client; pub(crate) mod client;
pub(crate) mod db; pub(crate) mod db;
mod retry;
pub(crate) mod table; pub(crate) mod table;
pub(crate) mod util; pub(crate) mod util;

View File

@@ -1,19 +1,19 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use http::HeaderName; use http::HeaderName;
use log::debug; use log::debug;
use reqwest::{ use reqwest::{
header::{HeaderMap, HeaderValue}, header::{HeaderMap, HeaderValue},
Body, Request, RequestBuilder, Response, Request, RequestBuilder, Response,
}; };
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions; use crate::remote::db::RemoteOptions;
use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id"); const REQUEST_ID_HEADER: &str = "x-request-id";
/// Configuration for the LanceDB Cloud HTTP client. /// Configuration for the LanceDB Cloud HTTP client.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -118,14 +118,41 @@ pub struct RetryConfig {
/// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable /// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
/// to set this value. Use a comma-separated list of integer values. /// to set this value. Use a comma-separated list of integer values.
/// ///
/// Note that write operations will never be retried on 5xx errors as this may /// The default is 429, 500, 502, 503.
/// result in duplicated writes.
///
/// The default is 409, 429, 500, 502, 503, 504.
pub statuses: Option<Vec<u16>>, pub statuses: Option<Vec<u16>>,
// TODO: should we allow customizing methods? // TODO: should we allow customizing methods?
} }
#[derive(Debug, Clone)]
struct ResolvedRetryConfig {
retries: u8,
connect_retries: u8,
read_retries: u8,
backoff_factor: f32,
backoff_jitter: f32,
statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![429, 500, 502, 503])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that // We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post: // we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients // https://write.as/balrogboogie/testing-reqwest-based-clients
@@ -133,8 +160,8 @@ pub struct RetryConfig {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> { pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
client: reqwest::Client, client: reqwest::Client,
host: String, host: String,
pub(crate) retry_config: ResolvedRetryConfig, retry_config: ResolvedRetryConfig,
pub(crate) sender: S, sender: S,
} }
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static { pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
@@ -272,7 +299,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
) -> Result<HeaderMap> { ) -> Result<HeaderMap> {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
HeaderName::from_static("x-api-key"), "x-api-key",
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput { HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(), message: "non-ascii api key provided".to_string(),
})?, })?,
@@ -280,7 +307,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if region == "local" { if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name); let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert( headers.insert(
http::header::HOST, "Host",
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput { HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name), message: format!("non-ascii database name '{}' provided", db_name),
})?, })?,
@@ -288,7 +315,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
} }
if has_host_override { if has_host_override {
headers.insert( headers.insert(
HeaderName::from_static("x-lancedb-database"), "x-lancedb-database",
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput { HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name), message: format!("non-ascii database name '{}' provided", db_name),
})?, })?,
@@ -296,7 +323,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
} }
if db_prefix.is_some() { if db_prefix.is_some() {
headers.insert( headers.insert(
HeaderName::from_static("x-lancedb-database-prefix"), "x-lancedb-database-prefix",
HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput { HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput {
message: format!( message: format!(
"non-ascii database prefix '{}' provided", "non-ascii database prefix '{}' provided",
@@ -308,7 +335,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if let Some(v) = options.0.get("account_name") { if let Some(v) = options.0.get("account_name") {
headers.insert( headers.insert(
HeaderName::from_static("x-azure-storage-account-name"), "x-azure-storage-account-name",
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name), message: format!("non-ascii storage account name '{}' provided", db_name),
})?, })?,
@@ -316,7 +343,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
} }
if let Some(v) = options.0.get("azure_storage_account_name") { if let Some(v) = options.0.get("azure_storage_account_name") {
headers.insert( headers.insert(
HeaderName::from_static("x-azure-storage-account-name"), "x-azure-storage-account-name",
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name), message: format!("non-ascii storage account name '{}' provided", db_name),
})?, })?,
@@ -348,12 +375,41 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
self.client.post(full_uri) self.client.post(full_uri)
} }
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> { pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
let (client, request) = req.build_split(); let (client, request) = req.build_split();
let mut request = request.unwrap(); let mut request = request.unwrap();
let request_id = self.extract_request_id(&mut request);
self.log_request(&request, &request_id);
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
let header = HeaderValue::from_str(&request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
request_id
};
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry {
self.send_with_retry_impl(client, request, request_id).await
} else {
let response = self let response = self
.sender .sender
.send(&client, request) .send(&client, request)
@@ -365,52 +421,28 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
); );
Ok((request_id, response)) Ok((request_id, response))
} }
/// Send the request using retries configured in the RetryConfig.
/// If retry_5xx is false, 5xx requests will not be retried regardless of the statuses configured
/// in the RetryConfig.
/// Since this requires arrow serialization, this is implemented here instead of in RestfulLanceDbClient
pub async fn send_with_retry(
&self,
req_builder: RequestBuilder,
mut make_body: Option<Box<dyn FnMut() -> Result<Body> + Send + 'static>>,
retry_5xx: bool,
) -> Result<(String, Response)> {
let retry_config = &self.retry_config;
let non_5xx_statuses = retry_config
.statuses
.iter()
.filter(|s| !s.is_server_error())
.cloned()
.collect::<Vec<_>>();
// clone and build the request to extract the request id
let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let (_, r) = tmp_req.build_split();
let mut r = r.unwrap();
let request_id = self.extract_request_id(&mut r);
let mut retry_counter = RetryCounter::new(retry_config, request_id.clone());
loop {
let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
// set the streaming body on the request builder after clone
if let Some(body_gen) = make_body.as_mut() {
let body = body_gen()?;
req_builder = req_builder.body(body);
} }
let (c, request) = req_builder.build_split(); async fn send_with_retry_impl(
let mut request = request.unwrap(); &self,
self.set_request_id(&mut request, &request_id.clone()); client: reqwest::Client,
self.log_request(&request, &request_id); req: Request,
request_id: String,
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r)); ) -> Result<(String, Response)> {
let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);
loop {
// This only works if the request body is not a stream. If it is
// a stream, we can't use the retry path. We would need to implement
// an outer retry.
let request = req.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let response = self
.sender
.send(&client, request)
.await
.map(|r| (r.status(), r));
match response { match response {
Ok((status, response)) if status.is_success() => { Ok((status, response)) if status.is_success() => {
debug!( debug!(
@@ -419,10 +451,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
); );
return Ok((retry_counter.request_id, response)); return Ok((retry_counter.request_id, response));
} }
Ok((status, response)) Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
if (retry_5xx && retry_config.statuses.contains(&status))
|| non_5xx_statuses.contains(&status) =>
{
let source = self let source = self
.check_response(&retry_counter.request_id, response) .check_response(&retry_counter.request_id, response)
.await .await
@@ -451,47 +480,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
} }
} }
fn log_request(&self, request: &Request, request_id: &String) {
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
}
/// Extract the request ID from the request headers.
/// If the request ID header is not set, this will generate a new one and set
/// it on the request headers
pub fn extract_request_id(&self, request: &mut Request) -> String {
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
self.set_request_id(request, &request_id);
request_id
};
request_id
}
/// Set the request ID header
pub fn set_request_id(&self, request: &mut Request, request_id: &str) {
let header = HeaderValue::from_str(request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
}
pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> { pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
// Try to get the response text, but if that fails, just return the status code // Try to get the response text, but if that fails, just return the status code
let status = response.status(); let status = response.status();
@@ -513,6 +501,91 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
} }
} }
struct RetryCounter<'a> {
request_failures: u8,
connect_failures: u8,
read_failures: u8,
config: &'a ResolvedRetryConfig,
request_id: String,
}
impl<'a> RetryCounter<'a> {
fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
pub trait RequestResultExt { pub trait RequestResultExt {
type Output; type Output;
fn err_to_http(self, request_id: String) -> Result<Self::Output>; fn err_to_http(self, request_id: String) -> Result<Self::Output>;

View File

@@ -52,10 +52,6 @@ impl ServerVersion {
pub fn support_multivector(&self) -> bool { pub fn support_multivector(&self) -> bool {
self.0 >= semver::Version::new(0, 2, 0) self.0 >= semver::Version::new(0, 2, 0)
} }
pub fn support_structural_fts(&self) -> bool {
self.0 >= semver::Version::new(0, 3, 0)
}
} }
pub const OPT_REMOTE_PREFIX: &str = "remote_database_"; pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
@@ -255,7 +251,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
if let Some(start_after) = request.start_after { if let Some(start_after) = request.start_after {
req = req.query(&[("page_token", start_after)]); req = req.query(&[("page_token", start_after)]);
} }
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?; let (request_id, rsp) = self.client.send(req, true).await?;
let rsp = self.client.check_response(&request_id, rsp).await?; let rsp = self.client.check_response(&request_id, rsp).await?;
let version = parse_server_version(&request_id, &rsp)?; let version = parse_server_version(&request_id, &rsp)?;
let tables = rsp let tables = rsp
@@ -302,7 +298,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.body(data_buffer) .body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, rsp) = self.client.send(req).await?; let (request_id, rsp) = self.client.send(req, false).await?;
if rsp.status() == StatusCode::BAD_REQUEST { if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await.err_to_http(request_id.clone())?; let body = rsp.text().await.err_to_http(request_id.clone())?;
@@ -362,7 +358,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
let req = self let req = self
.client .client
.post(&format!("/v1/table/{}/describe/", request.name)); .post(&format!("/v1/table/{}/describe/", request.name));
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?; let (request_id, rsp) = self.client.send(req, true).await?;
if rsp.status() == StatusCode::NOT_FOUND { if rsp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: request.name }); return Err(crate::Error::TableNotFound { name: request.name });
} }
@@ -383,7 +379,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.client .client
.post(&format!("/v1/table/{}/rename/", current_name)); .post(&format!("/v1/table/{}/rename/", current_name));
let req = req.json(&serde_json::json!({ "new_table_name": new_name })); let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
let (request_id, resp) = self.client.send(req).await?; let (request_id, resp) = self.client.send(req, false).await?;
self.client.check_response(&request_id, resp).await?; self.client.check_response(&request_id, resp).await?;
let table = self.table_cache.remove(current_name).await; let table = self.table_cache.remove(current_name).await;
if let Some(table) = table { if let Some(table) = table {
@@ -394,7 +390,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
async fn drop_table(&self, name: &str) -> Result<()> { async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name)); let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let (request_id, resp) = self.client.send(req).await?; let (request_id, resp) = self.client.send(req, true).await?;
self.client.check_response(&request_id, resp).await?; self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(name).await; self.table_cache.remove(name).await;
Ok(()) Ok(())

View File

@@ -1,122 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::remote::RetryConfig;
use crate::Error;
use log::debug;
use std::time::Duration;
pub struct RetryCounter<'a> {
pub request_failures: u8,
pub connect_failures: u8,
pub read_failures: u8,
pub config: &'a ResolvedRetryConfig,
pub request_id: String,
}
impl<'a> RetryCounter<'a> {
pub(crate) fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> crate::Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
pub fn increment_request_failures(&mut self, source: crate::Error) -> crate::Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_read_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
#[derive(Debug, Clone)]
pub struct ResolvedRetryConfig {
pub retries: u8,
pub connect_retries: u8,
pub read_retries: u8,
pub backoff_factor: f32,
pub backoff_jitter: f32,
pub statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> crate::Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![409, 429, 500, 502, 503, 504])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}

Some files were not shown because too many files have changed in this diff Show More