Compare commits

..

3 Commits

Author SHA1 Message Date
Lei Xu
b57faf9835 rename gha task 2026-01-29 16:26:28 -08:00
Lei Xu
4800f31479 fixlint 2026-01-29 16:21:23 -08:00
Lei Xu
ec464ad01e remove pydantic 1 support 2026-01-29 16:18:56 -08:00
69 changed files with 1182 additions and 3668 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.26.1"
current_version = "0.24.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -42,7 +42,7 @@ jobs:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build]
if: always() && failure() && startsWith(github.ref, 'refs/tags/v')
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
permissions:
contents: read
issues: write

View File

@@ -86,17 +86,16 @@ jobs:
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
Follow these steps exactly:
1. Use script "ci/set_lance_version.py" to update Lance Rust dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
2. Update the Java lance-core dependency version in "java/pom.xml": change the "<lance-core.version>...</lance-core.version>" property to "${VERSION}".
3. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
4. After clippy succeeds, run "cargo fmt --all" to format the workspace.
5. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
6. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
7. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
8. Push the branch to origin. If the remote branch already exists, delete it first with "gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/${BRANCH_NAME}" then push with "git push origin ${BRANCH_NAME}". Do NOT use "git push --force" or "git push -f".
9. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
10. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
11. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
1. Use script "ci/set_lance_version.py" to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
6. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
7. Push the branch to origin. If the branch already exists, force-push your changes.
8. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
9. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
Constraints:
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.

View File

@@ -8,7 +8,6 @@ on:
paths:
- Cargo.toml
- nodejs/**
- docs/src/js/**
- .github/workflows/nodejs.yml
- docker-compose.yml

View File

@@ -318,7 +318,7 @@ jobs:
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 24
node-version: 20
cache: npm
cache-dependency-path: nodejs/package-lock.json
registry-url: "https://registry.npmjs.org"
@@ -348,9 +348,9 @@ jobs:
run: find npm
- name: Publish
env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
DRY_RUN: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
npm config set provenance true
ARGS="--access public"
if [[ $DRY_RUN == "true" ]]; then
ARGS="$ARGS --dry-run"
@@ -363,7 +363,7 @@ jobs:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build-lancedb, test-lancedb, publish]
if: always() && failure() && startsWith(github.ref, 'refs/tags/v')
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
permissions:
contents: read
issues: write

View File

@@ -181,7 +181,7 @@ jobs:
permissions:
contents: read
issues: write
if: always() && failure() && startsWith(github.ref, 'refs/tags/python-v')
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue

View File

@@ -25,7 +25,7 @@ jobs:
lint:
name: "Lint"
timeout-minutes: 30
runs-on: "ubuntu-22.04"
runs-on: "ubuntu-24.04"
defaults:
run:
shell: bash
@@ -195,7 +195,7 @@ jobs:
# Make sure wheels are not included in the Rust cache
- name: Delete wheels
run: rm -rf target/wheels
pydantic1x:
min-deps:
timeout-minutes: 30
runs-on: "ubuntu-24.04"
defaults:
@@ -217,7 +217,6 @@ jobs:
python-version: "3.10"
- name: Install lancedb
run: |
pip install "pydantic<2"
pip install pyarrow==16
pip install --extra-index-url https://pypi.fury.io/lance-format/ --extra-index-url https://pypi.fury.io/lancedb/ -e .[tests]
pip install tantivy

853
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,40 +15,39 @@ categories = ["database-implementations"]
rust-version = "1.88.0"
[workspace.dependencies]
lance = { "version" = "=2.0.0", default-features = false }
lance-core = "=2.0.0"
lance-datagen = "=2.0.0"
lance-file = "=2.0.0"
lance-io = { "version" = "=2.0.0", default-features = false }
lance-index = "=2.0.0"
lance-linalg = "=2.0.0"
lance-namespace = "=2.0.0"
lance-namespace-impls = { "version" = "=2.0.0", default-features = false }
lance-table = "=2.0.0"
lance-testing = "=2.0.0"
lance-datafusion = "=2.0.0"
lance-encoding = "=2.0.0"
lance-arrow = "=2.0.0"
lance = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }
arrow-array = "57.2"
arrow-data = "57.2"
arrow-ipc = "57.2"
arrow-ord = "57.2"
arrow-schema = "57.2"
arrow-select = "57.2"
arrow-cast = "57.2"
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-select = "56.2"
arrow-cast = "56.2"
async-trait = "0"
datafusion = { version = "51.0", default-features = false }
datafusion-catalog = "51.0"
datafusion-common = { version = "51.0", default-features = false }
datafusion-execution = "51.0"
datafusion-expr = "51.0"
datafusion-physical-plan = "51.0"
datafusion-physical-expr = "51.0"
datafusion = { version = "50.1", default-features = false }
datafusion-catalog = "50.1"
datafusion-common = { version = "50.1", default-features = false }
datafusion-execution = "50.1"
datafusion-expr = "50.1"
datafusion-physical-plan = "50.1"
env_logger = "0.11"
half = { "version" = "2.7.1", default-features = false, features = [
half = { "version" = "2.6.0", default-features = false, features = [
"num-traits",
] }
futures = "0"

View File

@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
| REST API | https://docs.lancedb.com/api-reference/rest |
| REST API | https://docs.lancedb.com/api-reference/introduction |
## **Join Us and Contribute**

View File

@@ -1,62 +0,0 @@
# VoyageAI Embeddings
Voyage AI provides cutting-edge embedding and rerankers.
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
Supported models are:
**Voyage-4 Series (Latest)**
- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)
**Voyage-3 Series**
- voyage-3
- voyage-3-lite
**Domain-Specific Models**
- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|--------|---------|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.26.1</version>
<version>0.24.1</version>
</dependency>
```

View File

@@ -367,27 +367,6 @@ Use [Table.listIndices](Table.md#listindices) to find the names of the indices.
***
### initialStorageOptions()
```ts
abstract initialStorageOptions(): Promise<undefined | null | Record<string, string>>
```
Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
[Table.latestStorageOptions](Table.md#lateststorageoptions).
Warning: This is an internal API and the return value is subject to change.
#### Returns
`Promise`&lt;`undefined` \| `null` \| `Record`&lt;`string`, `string`&gt;&gt;
The storage options, or undefined if no storage options were configured.
***
### isOpen()
```ts
@@ -402,28 +381,6 @@ Return true if the table has not been closed
***
### latestStorageOptions()
```ts
abstract latestStorageOptions(): Promise<undefined | null | Record<string, string>>
```
Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
#### Returns
`Promise`&lt;`undefined` \| `null` \| `Record`&lt;`string`, `string`&gt;&gt;
The storage options, or undefined if no storage options were configured.
***
### listIndices()
```ts
@@ -748,11 +705,8 @@ Create a query that returns a subset of the rows in the table.
#### Parameters
* **rowIds**: readonly (`number` \| `bigint`)[]
* **rowIds**: `number`[]
The row ids of the rows to return.
Row ids returned by `withRowId()` are `bigint`, so `bigint[]` is supported.
For convenience / backwards compatibility, `number[]` is also accepted (for
small row ids that fit in a safe integer).
#### Returns

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.26.1-final.0</version>
<version>0.24.1-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.26.1-final.0</version>
<version>0.24.1-final.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>2.0.0</lance-core.version>
<lance-core.version>1.0.0-rc.2</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
@@ -292,12 +292,11 @@
<plugin>
<groupId>org.sonatype.central</groupId>
<artifactId>central-publishing-maven-plugin</artifactId>
<version>0.8.0</version>
<version>0.4.0</version>
<extensions>true</extensions>
<configuration>
<publishingServerId>ossrh</publishingServerId>
<tokenAuth>true</tokenAuth>
<autoPublish>true</autoPublish>
</configuration>
</plugin>
<plugin>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.26.1"
version = "0.24.1"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -312,66 +312,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(res.getChild("id")?.toJSON()).toEqual([2, 3]);
});
it("should support takeRowIds with bigint array", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Get actual row IDs using withRowId()
const allRows = await table.query().withRowId().toArray();
const rowIds = allRows.map((row) => row._rowid) as bigint[];
// Verify row IDs are bigint
expect(typeof rowIds[0]).toBe("bigint");
// Use takeRowIds with bigint array (the main use case from issue #2722)
const res = await table.takeRowIds([rowIds[0], rowIds[2]]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 3]);
});
it("should support takeRowIds with number array for backwards compatibility", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Small row IDs can be passed as numbers
const res = await table.takeRowIds([0, 2]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 3]);
});
it("should support takeRowIds with mixed bigint and number array", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Mixed array of bigint and number
const res = await table.takeRowIds([0n, 1, 2n]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 2, 3]);
});
it("should throw for non-integer number in takeRowIds", () => {
expect(() => table.takeRowIds([1.5])).toThrow(
"Row id must be an integer (or bigint)",
);
expect(() => table.takeRowIds([0, 1.1, 2])).toThrow(
"Row id must be an integer (or bigint)",
);
});
it("should throw for negative number in takeRowIds", () => {
expect(() => table.takeRowIds([-1])).toThrow("Row id cannot be negative");
expect(() => table.takeRowIds([0, -5, 2])).toThrow(
"Row id cannot be negative",
);
});
it("should throw for unsafe large number in takeRowIds", () => {
// Number.MAX_SAFE_INTEGER + 1 is not safe
const unsafeNumber = Number.MAX_SAFE_INTEGER + 1;
expect(() => table.takeRowIds([unsafeNumber])).toThrow(
"Row id is too large for number; use bigint instead",
);
});
it("should reject negative bigint in takeRowIds", async () => {
await table.add([{ id: 1 }]);
// Negative bigint should be rejected by the Rust layer
expect(() => {
table.takeRowIds([-1n]);
}).toThrow("Row id cannot be negative");
});
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
@@ -1580,9 +1520,9 @@ describe("when optimizing a dataset", () => {
it("delete unverified", async () => {
const version = await table.version();
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${String(
18446744073709551615n - (BigInt(version) - 1n),
).padStart(20, "0")}.manifest`;
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${
version - 1
}.manifest`;
fs.rmSync(versionFile);
let stats = await table.optimize({ deleteUnverified: false });

View File

@@ -347,13 +347,9 @@ export abstract class Table {
/**
* Create a query that returns a subset of the rows in the table.
* @param rowIds The row ids of the rows to return.
*
* Row ids returned by `withRowId()` are `bigint`, so `bigint[]` is supported.
* For convenience / backwards compatibility, `number[]` is also accepted (for
* small row ids that fit in a safe integer).
* @returns A builder that can be used to parameterize the query.
*/
abstract takeRowIds(rowIds: readonly (bigint | number)[]): TakeQuery;
abstract takeRowIds(rowIds: number[]): TakeQuery;
/**
* Create a search query to find the nearest neighbors
@@ -542,35 +538,6 @@ export abstract class Table {
*
*/
abstract stats(): Promise<TableStatistics>;
/**
* Get the initial storage options that were passed in when opening this table.
*
* For dynamically refreshed options (e.g., credential vending), use
* {@link Table.latestStorageOptions}.
*
* Warning: This is an internal API and the return value is subject to change.
*
* @returns The storage options, or undefined if no storage options were configured.
*/
abstract initialStorageOptions(): Promise<
Record<string, string> | null | undefined
>;
/**
* Get the latest storage options, refreshing from provider if configured.
*
* This method is useful for credential vending scenarios where storage options
* may be refreshed dynamically. If no dynamic provider is configured, this
* returns the initial static options.
*
* Warning: This is an internal API and the return value is subject to change.
*
* @returns The storage options, or undefined if no storage options were configured.
*/
abstract latestStorageOptions(): Promise<
Record<string, string> | null | undefined
>;
}
export class LocalTable extends Table {
@@ -719,24 +686,8 @@ export class LocalTable extends Table {
return new TakeQuery(this.inner.takeOffsets(offsets));
}
takeRowIds(rowIds: readonly (bigint | number)[]): TakeQuery {
const ids = rowIds.map((id) => {
if (typeof id === "bigint") {
return id;
}
if (!Number.isInteger(id)) {
throw new Error("Row id must be an integer (or bigint)");
}
if (id < 0) {
throw new Error("Row id cannot be negative");
}
if (!Number.isSafeInteger(id)) {
throw new Error("Row id is too large for number; use bigint instead");
}
return BigInt(id);
});
return new TakeQuery(this.inner.takeRowIds(ids));
takeRowIds(rowIds: number[]): TakeQuery {
return new TakeQuery(this.inner.takeRowIds(rowIds));
}
query(): Query {
@@ -907,18 +858,6 @@ export class LocalTable extends Table {
return await this.inner.stats();
}
async initialStorageOptions(): Promise<
Record<string, string> | null | undefined
> {
return await this.inner.initialStorageOptions();
}
async latestStorageOptions(): Promise<
Record<string, string> | null | undefined
> {
return await this.inner.latestStorageOptions();
}
mergeInsert(on: string | string[]): MergeInsertBuilder {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema());

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.26.1",
"version": "0.24.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",
@@ -8,9 +8,5 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -0,0 +1,3 @@
# `@lancedb/lancedb-darwin-x64`
This is the **x86_64-apple-darwin** binary for `@lancedb/lancedb`

View File

@@ -0,0 +1,12 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.24.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",
"files": ["lancedb.darwin-x64.node"],
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.26.1",
"version": "0.24.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",
@@ -9,9 +9,5 @@
"engines": {
"node": ">= 18"
},
"libc": ["glibc"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
"libc": ["glibc"]
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.26.1",
"version": "0.24.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",
@@ -9,9 +9,5 @@
"engines": {
"node": ">= 18"
},
"libc": ["musl"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
"libc": ["musl"]
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.26.1",
"version": "0.24.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",
@@ -9,9 +9,5 @@
"engines": {
"node": ">= 18"
},
"libc": ["glibc"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
"libc": ["glibc"]
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.26.1",
"version": "0.24.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",
@@ -9,9 +9,5 @@
"engines": {
"node": ">= 18"
},
"libc": ["musl"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
"libc": ["musl"]
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.26.1",
"version": "0.24.1",
"os": [
"win32"
],
@@ -14,9 +14,5 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.26.1",
"version": "0.24.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",
@@ -8,9 +8,5 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.26.1",
"version": "0.24.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.26.1",
"version": "0.24.1",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.26.1",
"version": "0.24.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
@@ -25,6 +25,7 @@
"triples": {
"defaults": false,
"additional": [
"x86_64-apple-darwin",
"aarch64-apple-darwin",
"x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu",
@@ -36,10 +37,6 @@
}
},
"license": "Apache-2.0",
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
},
"devDependencies": {
"@aws-sdk/client-dynamodb": "^3.33.0",
"@aws-sdk/client-kms": "^3.33.0",

View File

@@ -166,19 +166,6 @@ impl Table {
Ok(stats.into())
}
#[napi(catch_unwind)]
pub async fn initial_storage_options(&self) -> napi::Result<Option<HashMap<String, String>>> {
Ok(self.inner_ref()?.initial_storage_options().await)
}
#[napi(catch_unwind)]
pub async fn latest_storage_options(&self) -> napi::Result<Option<HashMap<String, String>>> {
self.inner_ref()?
.latest_storage_options()
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn update(
&self,
@@ -221,24 +208,18 @@ impl Table {
}
#[napi(catch_unwind)]
pub fn take_row_ids(&self, row_ids: Vec<BigInt>) -> napi::Result<TakeQuery> {
pub fn take_row_ids(&self, row_ids: Vec<i64>) -> napi::Result<TakeQuery> {
Ok(TakeQuery::new(
self.inner_ref()?.take_row_ids(
row_ids
.into_iter()
.map(|id| {
let (negative, value, lossless) = id.get_u64();
if negative {
Err(napi::Error::from_reason(
"Row id cannot be negative".to_string(),
.map(|o| {
u64::try_from(o).map_err(|e| {
napi::Error::from_reason(format!(
"Failed to convert row id to u64: {}",
e
))
} else if !lossless {
Err(napi::Error::from_reason(
"Row id is too large to fit in u64".to_string(),
))
} else {
Ok(value)
}
})
})
.collect::<Result<Vec<_>>>()?,
),

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.29.2"
current_version = "0.27.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.29.2"
version = "0.27.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -14,15 +14,15 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "57.2", features = ["pyarrow"] }
arrow = { version = "56.2", features = ["pyarrow"] }
async-trait = "0.1"
lancedb = { path = "../rust/lancedb", default-features = false }
lance-core.workspace = true
lance-namespace.workspace = true
lance-io.workspace = true
env_logger.workspace = true
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.26", features = [
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py310"] }
pyo3-async-runtimes = { version = "0.25", features = [
"attributes",
"tokio-runtime",
] }
@@ -32,9 +32,9 @@ snafu.workspace = true
tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.26", features = [
pyo3-build-config = { version = "0.25", features = [
"extension-module",
"abi3-py39",
"abi3-py310",
] }
[features]

View File

@@ -8,7 +8,7 @@ dependencies = [
"overrides>=0.7; python_version<'3.12'",
"packaging",
"pyarrow>=16",
"pydantic>=1.10",
"pydantic>=2",
"tqdm>=4.27.0",
"lance-namespace>=0.3.2"
]

View File

@@ -180,8 +180,6 @@ class Table:
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
async def uri(self) -> str: ...
async def initial_storage_options(self) -> Optional[Dict[str, str]]: ...
async def latest_storage_options(self) -> Optional[Dict[str, str]]: ...
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...

View File

@@ -22,12 +22,7 @@ class BackgroundEventLoop:
self.thread.start()
def run(self, future):
concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
try:
return concurrent_future.result()
except BaseException:
concurrent_future.cancel()
raise
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
LOOP = BackgroundEventLoop()

View File

@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
"""
Convert image inputs to PIL Images.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL_Image.open(image) as im:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL_Image.open(io.BytesIO(image)))
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)

View File

@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
if isinstance(inputs, list):
inputs = inputs
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
inputs = [inputs]
return inputs
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if parsed.scheme == "file":
pil_image = PIL_Image.open(parsed.path)
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
buffered = io.BytesIO()
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL_Image.Image):
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL_Image.Image):
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(

View File

@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image))
if isinstance(image, PIL_Image.Image):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL_Image.open(parsed.path)
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL_Image.open(image if os.name == "nt" else parsed.path)
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")

View File

@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("SigLIP supports str or PIL Image as query")
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
return image_features.cpu().detach().numpy().squeeze()
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
return image.convert("RGB") if image.mode != "RGB" else image
elif isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
if parsed.scheme == "file":
return PIL_Image.open(parsed.path).convert("RGB")
return PIL.Image.open(parsed.path).convert("RGB")
elif parsed.scheme == "":
path = image if os.name == "nt" else parsed.path
return PIL_Image.open(path).convert("RGB")
return PIL.Image.open(path).convert("RGB")
elif parsed.scheme.startswith("http"):
image_bytes = url_retrieve(image)
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
else:
raise NotImplementedError("Only local and http(s) urls are supported")
else:

View File

@@ -21,9 +21,6 @@ if TYPE_CHECKING:
# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-4": 320_000,
"voyage-4-lite": 1_000_000,
"voyage-4-large": 120_000,
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
@@ -64,7 +61,7 @@ def is_video_path(path: Path) -> bool:
def transform_input(input_data: Union[str, bytes, Path]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
if is_video_url(input_data):
@@ -73,7 +70,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL_Image.Image):
elif isinstance(input_data, PIL.Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -82,7 +79,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL_Image.open(BytesIO(input_data))
img = PIL.Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -101,7 +98,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"video_base64": video_str,
}
else:
img = PIL_Image.open(input_data)
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -119,8 +116,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs]
elif isinstance(inputs, list):
pass # Already a list, use as-is
@@ -133,7 +130,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
@@ -170,9 +167,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
name: str
The name of the model to use. List of acceptable models:
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
* voyage-4-lite (1024 dims, optimized for latency and cost)
* voyage-4-large (1024 dims, best retrieval quality)
* voyage-context-3
* voyage-3.5
* voyage-3.5-lite
@@ -221,9 +215,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-3.5",
"voyage-3.5-lite",
"voyage-3",
@@ -261,9 +252,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-context-3",
"voyage-3.5",
"voyage-3.5-lite",

View File

@@ -6,7 +6,6 @@
from __future__ import annotations
import inspect
import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
@@ -141,14 +140,6 @@ def Vector(
raise TypeError("A list of numbers or numpy.ndarray is needed")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {"type": "number"}
field_schema["maxItems"] = dim
field_schema["minItems"] = dim
return FixedSizeList
@@ -226,26 +217,14 @@ def MultiVector(
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, (list, range)):
raise TypeError("A list of vectors is needed")
for vec in v:
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
raise TypeError(f"Each vector must be a list of {dim} numbers")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
return MultiVectorList
@@ -281,20 +260,10 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
)
if PYDANTIC_VERSION.major < 2:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
]
else:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field)
for name, field in model.model_fields.items()
]
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.model_fields.items()
]
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
@@ -334,7 +303,7 @@ def _unwrap_optional_annotation(annotation: Any) -> Any | None:
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
elif isinstance(annotation, types.UnionType):
args = annotation.__args__
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
@@ -367,7 +336,7 @@ def is_nullable(field: FieldInfo) -> bool:
if origin == Union:
if any(typ is type(None) for typ in args):
return True
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
elif isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
for typ in args:
if typ is type(None):
@@ -474,8 +443,6 @@ class LanceModel(pydantic.BaseModel):
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
@classmethod
@@ -518,18 +485,8 @@ def get_extras(field_info: FieldInfo, key: str) -> Any:
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
if PYDANTIC_VERSION.major < 2:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.dict()
else:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()

View File

@@ -961,27 +961,22 @@ class LanceQueryBuilder(ABC):
>>> query = [100, 100]
>>> plan = table.search(query).analyze_plan()
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
AnalyzeExec verbose=true, elapsed=..., metrics=...
TracedExec, elapsed=..., metrics=...
ProjectionExec: elapsed=..., expr=[...],
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
GlobalLimitExec: elapsed=..., skip=0, fetch=10,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
FilterExec: elapsed=..., _distance@2 IS NOT NULL, metrics=[...]
SortExec: elapsed=..., TopK(fetch=10), expr=[...],
AnalyzeExec verbose=true, metrics=[], cumulative_cpu=...
TracedExec, metrics=[], cumulative_cpu=...
ProjectionExec: expr=[...], metrics=[...], cumulative_cpu=...
GlobalLimitExec: skip=0, fetch=10, metrics=[...], cumulative_cpu=...
FilterExec: _distance@2 IS NOT NULL,
metrics=[output_rows=..., elapsed_compute=...], cumulative_cpu=...
SortExec: TopK(fetch=10), expr=[...],
preserve_partitioning=[...],
metrics=[output_rows=..., elapsed_compute=...,
output_bytes=..., row_replacements=...]
KNNVectorDistance: elapsed=..., metric=l2,
metrics=[output_rows=..., elapsed_compute=...,
output_bytes=..., output_batches=...]
LanceRead: elapsed=..., uri=..., projection=[vector],
num_fragments=..., range_before=None, range_after=None,
row_id=true, row_addr=false,
full_filter=--, refine_filter=--,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...,
fragments_scanned=..., ranges_scanned=1, rows_scanned=1,
bytes_read=..., iops=..., requests=..., task_wait_time=...]
metrics=[output_rows=..., elapsed_compute=..., row_replacements=...],
cumulative_cpu=...
KNNVectorDistance: metric=l2,
metrics=[output_rows=..., elapsed_compute=..., output_batches=...],
cumulative_cpu=...
LanceRead: uri=..., projection=[vector], ...
metrics=[output_rows=..., elapsed_compute=...,
bytes_read=..., iops=..., requests=...], cumulative_cpu=...
Returns
-------
@@ -1433,19 +1428,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._bypass_vector_index = True
return self
def fast_search(self) -> LanceVectorQueryBuilder:
"""
Skip a flat search of unindexed data. This will improve
search performance but search results will not include unindexed data.
Returns
-------
LanceVectorQueryBuilder
The LanceVectorQueryBuilder object.
"""
self._fast_search = True
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""

View File

@@ -2222,37 +2222,6 @@ class LanceTable(Table):
def uri(self) -> str:
return LOOP.run(self._table.uri())
def initial_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
:meth:`latest_storage_options`.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return LOOP.run(self._table.initial_storage_options())
def latest_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return LOOP.run(self._table.latest_storage_options())
def create_scalar_index(
self,
column: str,
@@ -3655,37 +3624,6 @@ class AsyncTable:
"""
return await self._inner.uri()
async def initial_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
:meth:`latest_storage_options`.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return await self._inner.initial_storage_options()
async def latest_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return await self._inner.latest_storage_options()
async def add(
self,
data: DATA,

View File

@@ -517,36 +517,19 @@ def test_ollama_embedding(tmp_path):
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize(
"model_name,expected_dims",
[
("voyage-3", 1024),
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
],
)
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
"""Integration test for VoyageAI text embedding models with real API calls."""
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == expected_dims, (
f"{model_name} should have {expected_dims} dimensions"
)
# Test search functionality
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
@pytest.mark.slow

View File

@@ -438,15 +438,11 @@ def test_filter_with_splits(mem_db):
row_count = permutation_tbl.count_rows()
assert row_count == 67
# Verify the permutation table only contains row_id and split_id
assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ("A", "B") for cat in categories)
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):

View File

@@ -1499,30 +1499,3 @@ def test_search_empty_table(mem_db):
# Search on empty table should return empty results, not crash
results = table.search([1.0, 2.0]).limit(5).to_list()
assert results == []
def test_fast_search(tmp_path):
db = lancedb.connect(tmp_path)
# Generate data matching the async test style
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 32)
).storage
table = db.create_table("test", pa.table({"vector": vectors}))
# FIX: Pass arguments directly instead of using 'config=IvfPq(...)'
table.create_index(vector_column_name="vector", num_partitions=1, num_sub_vectors=1)
# Add data to ensure table has enough segments/rows
table.add(pa.table({"vector": vectors}))
q = [1.0] * 32
# 1. Normal Search -> Should include "LanceScan" (Brute Force / Scan)
plan = table.search(q).explain_plan(True)
assert "LanceScan" in plan
# 2. Fast Search -> Should NOT include "LanceScan" (Uses Index)
plan = table.search(q).fast_search().explain_plan(True)
assert "LanceScan" not in plan

View File

@@ -8,7 +8,7 @@ import http.server
import json
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import uuid
from packaging.version import Version
@@ -601,6 +601,7 @@ def test_head():
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"refine_factor": None,
@@ -684,6 +685,7 @@ def test_query_sync_maximal():
def test_query_sync_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
@@ -713,6 +715,7 @@ def test_query_sync_nprobes():
def test_query_sync_no_max_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
@@ -835,6 +838,7 @@ def test_query_sync_hybrid():
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": True,
"refine_factor": None,
@@ -1199,22 +1203,3 @@ async def test_header_provider_overrides_static_headers():
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
) as db:
await db.table_names()
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
from lancedb.background_loop import BackgroundEventLoop
mock_future = MagicMock()
mock_future.result.side_effect = exception()
with (
patch.object(BackgroundEventLoop, "__init__", return_value=None),
patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
):
loop = BackgroundEventLoop()
loop.loop = MagicMock()
with pytest.raises(exception):
loop.run(None)
mock_future.cancel.assert_called_once()

View File

@@ -1880,13 +1880,8 @@ async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_pat
],
)
version = await table.version()
assert version == 2
# By removing a manifest file, we make the data files we just inserted unverified
version_name = 18446744073709551615 - (version - 1)
path = tmp_path / "test.lance" / "_versions" / f"{version_name:020}.manifest"
path = tmp_path / "test.lance" / "_versions" / f"{version - 1}.manifest"
os.remove(path)
stats = await table.optimize(delete_unverified=False)
assert stats.prune.old_versions_removed == 0
stats = await table.optimize(

View File

@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Unit tests for VoyageAI embedding function.
These tests verify model registration and configuration without requiring API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from lancedb.embeddings import get_registry
@pytest.fixture(autouse=True)
def reset_voyageai_client():
"""Reset VoyageAI client before and after each test to avoid state pollution."""
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
VoyageAIEmbeddingFunction.client = None
yield
VoyageAIEmbeddingFunction.client = None
class TestVoyageAIModelRegistration:
"""Tests for VoyageAI model registration and configuration."""
@pytest.fixture
def mock_voyageai_client(self):
"""Mock VoyageAI client to avoid API calls."""
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
mock_client = MagicMock()
mock_voyageai = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock.return_value = mock_voyageai
yield mock_client
def test_voyageai_registered(self):
"""Test that VoyageAI is registered in the embedding function registry."""
registry = get_registry()
assert registry.get("voyageai") is not None
@pytest.mark.parametrize(
"model_name,expected_dims",
[
# Voyage-4 series (all 1024 dims)
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
# Voyage-3 series
("voyage-3", 1024),
("voyage-3-lite", 512),
# Domain-specific models
("voyage-finance-2", 1024),
("voyage-multilingual-2", 1024),
("voyage-law-2", 1024),
("voyage-code-2", 1536),
# Multimodal
("voyage-multimodal-3", 1024),
],
)
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
"""Test that each model returns the correct dimensions."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert func.ndims() == expected_dims, (
f"Model {model_name} should have {expected_dims} dimensions"
)
def test_unsupported_model_raises_error(self, mock_voyageai_client):
"""Test that unsupported models raise ValueError."""
registry = get_registry()
func = registry.get("voyageai").create(name="unsupported-model")
with pytest.raises(ValueError, match="not supported"):
func.ndims()
@pytest.mark.parametrize(
"model_name",
[
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
],
)
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
"""Test that voyage-4 models are classified as text models (not multimodal)."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert not func._is_multimodal_model(model_name), (
f"{model_name} should be a text model, not multimodal"
)
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
"""Test that voyage-4 models are in the text_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" in func.text_embedding_models
assert "voyage-4-lite" in func.text_embedding_models
assert "voyage-4-large" in func.text_embedding_models
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" not in func.multimodal_embedding_models
assert "voyage-4-lite" not in func.multimodal_embedding_models
assert "voyage-4-large" not in func.multimodal_embedding_models

View File

@@ -10,7 +10,8 @@ use arrow::{
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -35,11 +36,8 @@ impl RecordBatchStream {
#[pymethods]
impl RecordBatchStream {
#[getter]
pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> {
(*self.schema)
.clone()
.into_pyarrow(py)
.map(|obj| obj.unbind())
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
(*self.schema).clone().into_pyarrow(py)
}
pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
@@ -55,12 +53,7 @@ impl RecordBatchStream {
.next()
.await
.ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
Python::attach(|py| {
inner_next
.infer_error()?
.to_pyarrow(py)
.map(|obj| obj.unbind())
})
Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
})
}
}

View File

@@ -12,7 +12,7 @@ use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods,
types::{PyDict, PyDictMethods},
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -114,7 +114,7 @@ impl Connection {
data: Bound<'_, PyAny>,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>,
storage_options_provider: Option<PyObject>,
location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -152,7 +152,7 @@ impl Connection {
schema: Bound<'_, PyAny>,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>,
storage_options_provider: Option<PyObject>,
location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -187,7 +187,7 @@ impl Connection {
name: String,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<Py<PyAny>>,
storage_options_provider: Option<PyObject>,
index_cache_size: Option<u32>,
location: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -307,7 +307,7 @@ impl Connection {
..Default::default()
};
let response = inner.list_namespaces(request).await.infer_error()?;
Python::attach(|py| -> PyResult<Py<PyDict>> {
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("namespaces", response.namespaces)?;
dict.set_item("page_token", response.page_token)?;
@@ -345,7 +345,7 @@ impl Connection {
..Default::default()
};
let response = inner.create_namespace(request).await.infer_error()?;
Python::attach(|py| -> PyResult<Py<PyDict>> {
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
Ok(dict.unbind())
@@ -386,7 +386,7 @@ impl Connection {
..Default::default()
};
let response = inner.drop_namespace(request).await.infer_error()?;
Python::attach(|py| -> PyResult<Py<PyDict>> {
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
dict.set_item("transaction_id", response.transaction_id)?;
@@ -413,7 +413,7 @@ impl Connection {
..Default::default()
};
let response = inner.describe_namespace(request).await.infer_error()?;
Python::attach(|py| -> PyResult<Py<PyDict>> {
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
Ok(dict.unbind())
@@ -443,7 +443,7 @@ impl Connection {
..Default::default()
};
let response = inner.list_tables(request).await.infer_error()?;
Python::attach(|py| -> PyResult<Py<PyDict>> {
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("tables", response.tables)?;
dict.set_item("page_token", response.page_token)?;

View File

@@ -40,7 +40,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
request_id,
source,
status_code,
} => Python::attach(|py| {
} => Python::with_gil(|py| {
let message = err.to_string();
let http_err_cls = py
.import(intern!(py, "lancedb.remote.errors"))?
@@ -75,7 +75,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
max_read_failures,
source,
status_code,
} => Python::attach(|py| {
} => Python::with_gil(|py| {
let cause_err = http_from_rust_error(
py,
source.as_ref(),

View File

@@ -12,7 +12,7 @@ pub struct PyHeaderProvider {
impl Clone for PyHeaderProvider {
fn clone(&self) -> Self {
Python::attach(|py| Self {
Python::with_gil(|py| Self {
provider: self.provider.clone_ref(py),
})
}
@@ -25,7 +25,7 @@ impl PyHeaderProvider {
/// Get headers from the Python provider (internal implementation)
fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
Python::attach(|py| {
Python::with_gil(|py| {
// Call the get_headers method
let result = self.provider.call_method0(py, "get_headers");

View File

@@ -281,7 +281,7 @@ impl PyPermutationReader {
let reader = slf.reader.clone();
future_into_py(slf.py(), async move {
let schema = reader.output_schema(selection).await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}

View File

@@ -453,7 +453,7 @@ impl Query {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
@@ -532,7 +532,7 @@ impl TakeQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
@@ -627,7 +627,7 @@ impl FTSQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
@@ -806,7 +806,7 @@ impl VectorQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}

View File

@@ -17,20 +17,20 @@ use pyo3::types::PyDict;
/// Internal wrapper around a Python object implementing StorageOptionsProvider
pub struct PyStorageOptionsProvider {
/// The Python object implementing fetch_storage_options()
inner: Py<PyAny>,
inner: PyObject,
}
impl Clone for PyStorageOptionsProvider {
fn clone(&self) -> Self {
Python::attach(|py| Self {
Python::with_gil(|py| Self {
inner: self.inner.clone_ref(py),
})
}
}
impl PyStorageOptionsProvider {
pub fn new(obj: Py<PyAny>) -> PyResult<Self> {
Python::attach(|py| {
pub fn new(obj: PyObject) -> PyResult<Self> {
Python::with_gil(|py| {
// Verify the object has a fetch_storage_options method
if !obj.bind(py).hasattr("fetch_storage_options")? {
return Err(pyo3::exceptions::PyTypeError::new_err(
@@ -60,7 +60,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
let py_provider = self.py_provider.clone();
tokio::task::spawn_blocking(move || {
Python::attach(|py| {
Python::with_gil(|py| {
// Call the Python fetch_storage_options method
let result = py_provider
.inner
@@ -119,7 +119,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
}
fn provider_id(&self) -> String {
Python::attach(|py| {
Python::with_gil(|py| {
// Call provider_id() method on the Python object
let obj = self.py_provider.inner.bind(py);
obj.call_method0("provider_id")
@@ -143,7 +143,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
/// This is the main entry point for converting Python StorageOptionsProvider objects
/// to Rust trait objects that can be used by the Lance ecosystem.
pub fn py_object_to_storage_options_provider(
py_obj: Py<PyAny>,
py_obj: PyObject,
) -> PyResult<Arc<dyn StorageOptionsProvider>> {
let py_provider = PyStorageOptionsProvider::new(py_obj)?;
Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))

View File

@@ -287,7 +287,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let schema = inner.schema().await.infer_error()?;
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
@@ -437,7 +437,7 @@ impl Table {
future_into_py(self_.py(), async move {
let stats = inner.index_stats(&index_name).await.infer_error()?;
if let Some(stats) = stats {
Python::attach(|py| {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
dict.set_item("num_unindexed_rows", stats.num_unindexed_rows)?;
@@ -467,7 +467,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let stats = inner.stats().await.infer_error()?;
Python::attach(|py| {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("total_bytes", stats.total_bytes)?;
dict.set_item("num_rows", stats.num_rows)?;
@@ -502,20 +502,6 @@ impl Table {
future_into_py(self_.py(), async move { inner.uri().await.infer_error() })
}
pub fn initial_storage_options(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
Ok(inner.initial_storage_options().await)
})
}
pub fn latest_storage_options(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.latest_storage_options().await.infer_error()
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),
@@ -535,7 +521,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let versions = inner.list_versions().await.infer_error()?;
let versions_as_dict = Python::attach(|py| {
let versions_as_dict = Python::with_gil(|py| {
versions
.iter()
.map(|v| {
@@ -886,7 +872,7 @@ impl Tags {
let tags = inner.tags().await.infer_error()?;
let res = tags.list().await.infer_error()?;
Python::attach(|py| {
Python::with_gil(|py| {
let py_dict = PyDict::new(py);
for (key, contents) in res {
let value_dict = PyDict::new(py);

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.26.1"
version = "0.24.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -25,7 +25,6 @@ datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }

View File

@@ -251,36 +251,8 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let embedding_registry = self.embedding_registry.clone();
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
fn into_request(self) -> Result<CreateTableRequest> {
if self.embeddings.is_empty() {
return Ok(self.request);
}
let CreateTableData::Empty(table_def) = self.request.data else {
unreachable!("CreateTableBuilder<false> should always have Empty data")
};
let schema = table_def.schema.clone();
let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
let table_definition = with_embeddings.table_definition()?;
Ok(CreateTableRequest {
data: CreateTableData::Empty(table_definition),
..self.request
})
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
@@ -1720,128 +1692,4 @@ mod tests {
let cloned_count = cloned_table.count_rows(None).await.unwrap();
assert_eq!(source_count, cloned_count);
}
#[tokio::test]
async fn test_create_empty_table_with_embeddings() {
use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use std::borrow::Cow;
#[derive(Debug, Clone)]
struct MockEmbedding {
dim: usize,
}
impl EmbeddingFunction for MockEmbedding {
fn name(&self) -> &str {
"test_embedding"
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as i32,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let len = source.len();
let values = vec![1.0f32; len * self.dim];
let values = Arc::new(Float32Array::from(values));
let field = Arc::new(Field::new("item", DataType::Float32, true));
Ok(Arc::new(FixedSizeListArray::new(
field,
self.dim as i32,
values,
None,
)))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let embed_func = Arc::new(MockEmbedding { dim: 128 });
db.embedding_registry()
.register("test_embedding", embed_func.clone())
.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let ed = EmbeddingDefinition {
source_column: "name".to_owned(),
dest_column: Some("name_embedding".to_owned()),
embedding_name: "test_embedding".to_owned(),
};
let table = db
.create_empty_table("test", schema)
.mode(CreateTableMode::Overwrite)
.add_embedding(ed)
.unwrap()
.execute()
.await
.unwrap();
let table_schema = table.schema().await.unwrap();
assert!(table_schema.column_with_name("name").is_some());
assert!(table_schema.column_with_name("name_embedding").is_some());
let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
assert_eq!(
embedding_field.data_type(),
&DataType::new_fixed_size_list(DataType::Float32, 128, true)
);
let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let input_batch = RecordBatch::try_new(
input_schema.clone(),
vec![Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Charlie"),
]))],
)
.unwrap();
let input_reader = Box::new(RecordBatchIterator::new(
vec![Ok(input_batch)].into_iter(),
input_schema,
));
table.add(input_reader).execute().await.unwrap();
let results = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert!(batch.column_by_name("name_embedding").is_some());
let embedding_col = batch
.column_by_name("name_embedding")
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
assert_eq!(embedding_col.len(), 3);
}
}

View File

@@ -12,8 +12,6 @@ use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance_arrow::SchemaExt;
use lance_core::ROW_ID;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
@@ -362,15 +360,11 @@ impl Splitter {
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
(SPLIT_ID_COLUMN.to_string(), calculation.clone()),
(ROW_ID.to_string(), ROW_ID.to_string()),
])),
SplitStrategy::Hash { columns, .. } => {
let mut cols = columns.clone();
cols.push(ROW_ID.to_string());
query.select(Select::Columns(cols))
}
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}

View File

@@ -1,8 +1,6 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod insert;
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
@@ -470,9 +468,7 @@ impl<S: HttpSend> RemoteTable<S> {
self.apply_query_params(&mut body, &query.base)?;
// Apply general parameters, before we dispatch based on number of query vectors.
if let Some(distance_type) = query.distance_type {
body["distance_type"] = serde_json::json!(distance_type);
}
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
// In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`.
// Old client / new server: since minimum_nprobes is missing, fallback to nprobes
// New client / old server: old server will only see nprobes, make sure to set both
@@ -1497,14 +1493,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
None
}
async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
None
}
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
Ok(None)
}
async fn stats(&self) -> Result<TableStatistics> {
let request = self
.client
@@ -1520,21 +1508,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
})?;
Ok(stats)
}
async fn create_insert_exec(
&self,
input: Arc<dyn ExecutionPlan>,
write_params: lance::dataset::WriteParams,
) -> Result<Arc<dyn ExecutionPlan>> {
let overwrite = matches!(write_params.mode, lance::dataset::WriteMode::Overwrite);
Ok(Arc::new(insert::RemoteInsertExec::new(
self.name.clone(),
self.identifier.clone(),
self.client.clone(),
input,
overwrite,
)))
}
}
#[derive(Serialize)]
@@ -2257,6 +2230,7 @@ mod tests {
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"minimum_nprobes": 20,
"maximum_nprobes": 20,

View File

@@ -1,438 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! DataFusion ExecutionPlan for inserting data into remote LanceDB tables.
use std::any::Any;
use std::sync::{Arc, Mutex};
use arrow_array::{ArrayRef, RecordBatch, UInt64Array};
use arrow_ipc::CompressionType;
use arrow_schema::ArrowError;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::StreamExt;
use http::header::CONTENT_TYPE;
use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
use crate::remote::table::RemoteTable;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::insert::COUNT_SCHEMA;
use crate::table::AddResult;
use crate::Error;
/// ExecutionPlan for inserting data into a remote LanceDB table.
///
/// This plan:
/// 1. Requires single partition (no parallel remote inserts yet)
/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint
/// 3. Stores AddResult for retrieval after execution
#[derive(Debug)]
pub struct RemoteInsertExec<S: HttpSend = Sender> {
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
properties: PlanProperties,
add_result: Arc<Mutex<Option<AddResult>>>,
}
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
/// Create a new RemoteInsertExec.
pub fn new(
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
) -> Self {
let schema = COUNT_SCHEMA.clone();
let properties = PlanProperties::new(
EquivalenceProperties::new(schema),
datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
datafusion_physical_plan::execution_plan::EmissionType::Final,
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
);
Self {
table_name,
identifier,
client,
input,
overwrite,
properties,
add_result: Arc::new(Mutex::new(None)),
}
}
/// Get the add result after execution.
// TODO: this will be used when we wire this up to Table::add().
#[allow(dead_code)]
pub fn add_result(&self) -> Option<AddResult> {
self.add_result.lock().unwrap().clone()
}
fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult<reqwest::Body> {
let options = arrow_ipc::writer::IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
let writer = arrow_ipc::writer::StreamWriter::try_new_with_options(
Vec::new(),
&data.schema(),
options,
)?;
let stream = futures::stream::try_unfold((data, writer), move |(mut data, mut writer)| {
async move {
match data.next().await {
Some(Ok(batch)) => {
writer.write(&batch)?;
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
}
Some(Err(e)) => Err(e),
None => {
if let Err(ArrowError::IpcError(_msg)) = writer.finish() {
// Will error if already closed.
return Ok(None);
};
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
}
}
}
});
Ok(reqwest::Body::wrap_stream(stream))
}
}
impl<S: HttpSend + 'static> DisplayAs for RemoteInsertExec<S> {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"RemoteInsertExec: table={}, overwrite={}",
self.table_name, self.overwrite
)
}
DisplayFormatType::TreeRender => {
write!(f, "RemoteInsertExec")
}
}
}
}
impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![false]
}
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
// Until we have a separate commit endpoint, we need to do all inserts in a single partition
vec![datafusion_physical_plan::Distribution::SinglePartition]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"RemoteInsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new(
self.table_name.clone(),
self.identifier.clone(),
self.client.clone(),
children[0].clone(),
self.overwrite,
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
if partition != 0 {
return Err(DataFusionError::Internal(
"RemoteInsertExec only supports single partition execution".to_string(),
));
}
let input_stream = self.input.execute(0, context)?;
let client = self.client.clone();
let identifier = self.identifier.clone();
let overwrite = self.overwrite;
let add_result = self.add_result.clone();
let table_name = self.table_name.clone();
let stream = futures::stream::once(async move {
let mut request = client
.post(&format!("/v1/table/{}/insert/", identifier))
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
if overwrite {
request = request.query(&[("mode", "overwrite")]);
}
let body = Self::stream_as_body(input_stream)?;
let request = request.body(body);
let (request_id, response) = client
.send(request)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response =
RemoteTable::<Sender>::handle_table_not_found(&table_name, response, &request_id)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response = client
.check_response(&request_id, response)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let body_text = response.text().await.map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),
request_id: request_id.clone(),
status_code: None,
}))
})?;
let parsed_result = if body_text.trim().is_empty() {
// Backward compatible with old servers
AddResult { version: 0 }
} else {
serde_json::from_str(&body_text).map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: format!("Failed to parse add response: {}", e).into(),
request_id: request_id.clone(),
status_code: None,
}))
})?
};
{
let mut res_lock = add_result.lock().map_err(|_| {
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
})?;
*res_lock = Some(parsed_result);
}
// Return a single batch with count 0 (actual count is tracked in add_result)
let count_array: ArrayRef = Arc::new(UInt64Array::from(vec![0u64]));
let batch = RecordBatch::try_new(COUNT_SCHEMA.clone(), vec![count_array])?;
Ok::<_, DataFusionError>(batch)
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
COUNT_SCHEMA.clone(),
stream,
)))
}
}
#[cfg(test)]
mod tests {
use arrow_array::record_batch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::BaseTableAdapter;
use crate::Table;
fn schema_json() -> &'static str {
r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"#
}
#[tokio::test]
async fn test_remote_insert_exec_execute_empty() {
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
let path = request.url().path();
if path == "/v1/table/my_table/describe/" {
// Return schema for BaseTableAdapter::try_new
return http::Response::builder()
.status(200)
.body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json()))
.unwrap();
}
if path == "/v1/table/my_table/insert/" {
assert_eq!(request.method(), "POST");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
request_count_clone.fetch_add(1, Ordering::SeqCst);
return http::Response::builder()
.status(200)
.body(r#"{"version": 2}"#.to_string())
.unwrap();
}
panic!("Unexpected request path: {}", path);
});
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
// Create empty MemTable (no batches)
let source_table = MemTable::try_new(schema, vec![vec![]]).unwrap();
let ctx = SessionContext::new();
// Register the remote table as insert target
let provider = BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("my_table", Arc::new(provider)).unwrap();
// Register empty source
ctx.register_table("empty_source", Arc::new(source_table))
.unwrap();
// Execute the INSERT
ctx.sql("INSERT INTO my_table SELECT * FROM empty_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have made exactly one HTTP request even with empty input
assert_eq!(request_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_remote_insert_exec_multi_partition() {
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
let path = request.url().path();
if path == "/v1/table/my_table/describe/" {
// Return schema for BaseTableAdapter::try_new
return http::Response::builder()
.status(200)
.body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json()))
.unwrap();
}
if path == "/v1/table/my_table/insert/" {
assert_eq!(request.method(), "POST");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
request_count_clone.fetch_add(1, Ordering::SeqCst);
return http::Response::builder()
.status(200)
.body(r#"{"version": 2}"#.to_string())
.unwrap();
}
panic!("Unexpected request path: {}", path);
});
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
// Create MemTable with multiple partitions and multiple batches
let source_table = MemTable::try_new(
schema,
vec![
// Partition 0
vec![
record_batch!(("id", Int32, [1, 2])).unwrap(),
record_batch!(("id", Int32, [3, 4])).unwrap(),
],
// Partition 1
vec![record_batch!(("id", Int32, [5, 6, 7])).unwrap()],
// Partition 2
vec![record_batch!(("id", Int32, [8])).unwrap()],
],
)
.unwrap();
let ctx = SessionContext::new();
// Register the remote table as insert target
let provider = BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("my_table", Arc::new(provider)).unwrap();
// Register multi-partition source
ctx.register_table("multi_partition_source", Arc::new(source_table))
.unwrap();
// Get the physical plan and verify it includes a repartition to 1
let df = ctx
.sql("INSERT INTO my_table SELECT * FROM multi_partition_source")
.await
.unwrap();
let plan = df.clone().create_physical_plan().await.unwrap();
let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
.indent(true)
.to_string();
// The plan should include a CoalescePartitionsExec to merge partitions
assert!(
plan_str.contains("CoalescePartitionsExec"),
"Expected CoalescePartitionsExec in plan:\n{}",
plan_str
);
// Execute the INSERT
df.collect().await.unwrap();
// Verify: should have made exactly one HTTP request despite multiple input partitions
assert_eq!(request_count.load(Ordering::SeqCst), 1);
}
}

View File

@@ -23,7 +23,9 @@ pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
use lance::dataset::{InsertBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{
InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::index::vector::utils::infer_vector_dim;
use lance::index::vector::VectorIndexParams;
@@ -77,14 +79,10 @@ use self::merge::MergeInsertBuilder;
pub mod datafusion;
pub(crate) mod dataset;
pub mod delete;
pub mod merge;
pub mod schema_evolution;
pub mod update;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use delete::DeleteResult;
use futures::future::{join_all, Either};
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
@@ -92,9 +90,7 @@ pub use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::statistics::DatasetStatisticsExt;
use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME;
pub use lance_index::optimize::OptimizeOptions;
pub use schema_evolution::{AddColumnsResult, AlterColumnsResult, DropColumnsResult};
use serde_with::skip_serializing_none;
pub use update::{UpdateBuilder, UpdateResult};
/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -331,6 +327,72 @@ impl<T: IntoArrow> AddDataBuilder<T> {
}
}
/// A builder for configuring an [`Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn BaseTable>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
filter: None,
columns: Vec::new(),
}
}
/// Limits the update operation to rows matching the given filter
///
/// If a row does not match the filter then it will be left unchanged.
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
self.filter = Some(filter.into());
self
}
/// Specifies a column to update
///
/// This method may be called multiple times to update multiple columns
///
/// The `update_expr` should be an SQL expression explaining how to calculate
/// the new value for the column. The expression will be evaluated against the
/// previous row's value.
///
/// # Examples
///
/// ```
/// # use lancedb::Table;
/// # async fn doctest_helper(tbl: Table) {
/// let mut operation = tbl.update();
/// // Increments the `bird_count` value by 1
/// operation = operation.column("bird_count", "bird_count + 1");
/// operation.execute().await.unwrap();
/// # }
/// ```
pub fn column(
mut self,
column_name: impl Into<String>,
update_expr: impl Into<String>,
) -> Self {
self.columns.push((column_name.into(), update_expr.into()));
self
}
/// Executes the update operation.
/// Returns the update result
pub async fn execute(self) -> Result<UpdateResult> {
if self.columns.is_empty() {
Err(Error::InvalidInput {
message: "at least one column must be specified in an update operation".to_string(),
})
} else {
self.parent.clone().update(self).await
}
}
}
/// Filters that can be used to limit the rows returned by a query
pub enum Filter {
/// A SQL filter string
@@ -364,6 +426,17 @@ pub trait Tags: Send + Sync {
async fn update(&mut self, tag: &str, version: u64) -> Result<()>;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct UpdateResult {
#[serde(default)]
pub rows_updated: u64,
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddResult {
// The commit version associated with the operation.
@@ -373,6 +446,15 @@ pub struct AddResult {
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct MergeResult {
// The commit version associated with the operation.
@@ -398,6 +480,33 @@ pub struct MergeResult {
pub num_attempts: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AlterColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DropColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
@@ -502,17 +611,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// Get the table URI (storage location)
async fn uri(&self) -> Result<String>;
/// Get the storage options used when opening this table, if any.
#[deprecated(since = "0.25.0", note = "Use initial_storage_options() instead")]
async fn storage_options(&self) -> Option<HashMap<String, String>>;
/// Get the initial storage options that were passed in when opening this table.
///
/// For dynamically refreshed options (e.g., credential vending), use [`Self::latest_storage_options`].
async fn initial_storage_options(&self) -> Option<HashMap<String, String>>;
/// Get the latest storage options, refreshing from provider if configured.
///
/// Returns `Ok(Some(options))` if storage options are available (static or refreshed),
/// `Ok(None)` if no storage options were configured, or `Err(...)` if refresh failed.
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>>;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
@@ -522,19 +621,6 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
) -> Result<()>;
/// Get statistics on the table
async fn stats(&self) -> Result<TableStatistics>;
/// Create an ExecutionPlan for inserting data into the table.
///
/// This is used by the DataFusion TableProvider implementation to support
/// INSERT INTO statements.
async fn create_insert_exec(
&self,
_input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
_write_params: WriteParams,
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
Err(Error::NotSupported {
message: "create_insert_exec not implemented".to_string(),
})
}
}
/// A Table is a collection of strong typed Rows.
@@ -1242,32 +1328,10 @@ impl Table {
/// Get the storage options used when opening this table, if any.
///
/// Warning: This is an internal API and the return value is subject to change.
#[deprecated(since = "0.25.0", note = "Use initial_storage_options() instead")]
pub async fn storage_options(&self) -> Option<HashMap<String, String>> {
#[allow(deprecated)]
self.inner.storage_options().await
}
/// Get the initial storage options that were passed in when opening this table.
///
/// For dynamically refreshed options (e.g., credential vending), use [`Self::latest_storage_options`].
///
/// Warning: This is an internal API and the return value is subject to change.
pub async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
self.inner.initial_storage_options().await
}
/// Get the latest storage options, refreshing from provider if configured.
///
/// This method is useful for credential vending scenarios where storage options
/// may be refreshed dynamically. If no dynamic provider is configured, this
/// returns the initial static options.
///
/// Warning: This is an internal API and the return value is subject to change.
pub async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
self.inner.latest_storage_options().await
}
/// Get statistics about an index.
/// Returns None if the index does not exist.
pub async fn index_stats(
@@ -1361,9 +1425,7 @@ impl Table {
})
.collect::<Vec<_>>();
let unioned = UnionExec::try_new(projected_plans).map_err(|err| Error::Runtime {
message: err.to_string(),
})?;
let unioned = Arc::new(UnionExec::new(projected_plans));
// We require 1 partition in the final output
let repartitioned = RepartitionExec::try_new(
unioned,
@@ -2740,8 +2802,25 @@ impl BaseTable for NativeTable {
}
async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> {
// Delegate to the submodule implementation
update::execute_update(self, update).await
let dataset = self.dataset.get().await?.clone();
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
if let Some(predicate) = update.filter {
builder = builder.update_where(&predicate)?;
}
for (column, value) in update.columns {
builder = builder.set(column, &value)?;
}
let operation = builder.build()?;
let res = operation.execute().await?;
self.dataset
.set_latest(res.new_dataset.as_ref().clone())
.await;
Ok(UpdateResult {
rows_updated: res.rows_updated,
version: res.new_dataset.version().version,
})
}
async fn create_plan(
@@ -2999,8 +3078,11 @@ impl BaseTable for NativeTable {
/// Delete rows from the table
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
// Delegate to the submodule implementation
delete::execute_delete(self, predicate).await
let mut dataset = self.dataset.get_mut().await?;
dataset.delete(predicate).await?;
Ok(DeleteResult {
version: dataset.version().version,
})
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -3066,15 +3148,27 @@ impl BaseTable for NativeTable {
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult> {
schema_evolution::execute_add_columns(self, transforms, read_columns).await
let mut dataset = self.dataset.get_mut().await?;
dataset.add_columns(transforms, read_columns, None).await?;
Ok(AddColumnsResult {
version: dataset.version().version,
})
}
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult> {
schema_evolution::execute_alter_columns(self, alterations).await
let mut dataset = self.dataset.get_mut().await?;
dataset.alter_columns(alterations).await?;
Ok(AlterColumnsResult {
version: dataset.version().version,
})
}
async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> {
schema_evolution::execute_drop_columns(self, columns).await
let mut dataset = self.dataset.get_mut().await?;
dataset.drop_columns(columns).await?;
Ok(DropColumnsResult {
version: dataset.version().version,
})
}
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
@@ -3137,10 +3231,6 @@ impl BaseTable for NativeTable {
}
async fn storage_options(&self) -> Option<HashMap<String, String>> {
self.initial_storage_options().await
}
async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
self.dataset
.get()
.await
@@ -3148,11 +3238,6 @@ impl BaseTable for NativeTable {
.and_then(|dataset| dataset.initial_storage_options().cloned())
}
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
let dataset = self.dataset.get().await?;
Ok(dataset.latest_storage_options().await?.map(|o| o.0))
}
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
let stats = match self
.dataset
@@ -3266,21 +3351,6 @@ impl BaseTable for NativeTable {
};
Ok(stats)
}
async fn create_insert_exec(
&self,
input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
write_params: WriteParams,
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
let ds = self.dataset.get().await?;
let dataset = Arc::new((*ds).clone());
Ok(Arc::new(datafusion::insert::InsertExec::new(
self.dataset.clone(),
dataset,
input,
write_params,
)))
}
}
#[skip_serializing_none]
@@ -3336,12 +3406,15 @@ mod tests {
use arrow_array::{
builder::{ListBuilder, StringBuilder},
Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray,
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
};
use arrow_array::{BinaryArray, LargeBinaryArray};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use lance::dataset::WriteMode;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance::Dataset;
@@ -3353,6 +3426,7 @@ mod tests {
use crate::connection::ConnectBuilder;
use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder};
use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder};
use crate::query::{ExecutableQuery, QueryBase};
#[tokio::test]
async fn test_open() {
@@ -3574,6 +3648,306 @@ mod tests {
assert_eq!(table.name(), "test");
}
#[tokio::test]
async fn test_update_with_predicate() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.update()
.only_if("id > 5")
.column("name", "'foo'")
.execute()
.await
.unwrap();
let mut batches = table
.query()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_all_types() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
let mut update_op = table.update();
for (column, value) in updates {
update_op = update_op.column(column, value);
}
update_op.execute().await.unwrap();
let mut batches = table
.query()
.select(Select::columns(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
#[tokio::test]
async fn test_update_via_expr() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let tbl = conn
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
tbl.update().column("i", "i+1").execute().await.unwrap();
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
}
#[derive(Default, Debug)]
struct NoOpCacheWrapper {
called: AtomicBool,

View File

@@ -3,7 +3,6 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
pub mod insert;
pub mod udtf;
use std::{collections::HashMap, sync::Arc};
@@ -14,12 +13,11 @@ use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{dml::InsertOp, Expr, TableProviderFilterPushDown, TableType};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};
use lance::dataset::{WriteMode, WriteParams};
use super::{AnyQuery, BaseTable};
use crate::{
@@ -252,33 +250,6 @@ impl TableProvider for BaseTableAdapter {
// TODO
None
}
async fn insert_into(
&self,
_state: &dyn Session,
input: Arc<dyn ExecutionPlan>,
insert_op: InsertOp,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let mode = match insert_op {
InsertOp::Append => WriteMode::Append,
InsertOp::Overwrite => WriteMode::Overwrite,
InsertOp::Replace => {
return Err(DataFusionError::NotImplemented(
"Replace mode is not supported for LanceDB tables".to_string(),
))
}
};
let write_params = WriteParams {
mode,
..Default::default()
};
self.table
.create_insert_exec(input, write_params)
.await
.map_err(|e| DataFusionError::External(e.into()))
}
}
#[cfg(test)]

View File

@@ -1,446 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! DataFusion ExecutionPlan for inserting data into LanceDB tables.
use std::any::Any;
use std::sync::{Arc, LazyLock, Mutex};
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use lance::dataset::transaction::{Operation, Transaction};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
use lance::Dataset;
use lance_table::format::Fragment;
use crate::table::dataset::DatasetConsistencyWrapper;
pub(crate) static COUNT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
Arc::new(ArrowSchema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]))
});
fn operation_fragments(operation: &Operation) -> &[Fragment] {
match operation {
Operation::Append { fragments } => fragments,
Operation::Overwrite { fragments, .. } => fragments,
_ => &[],
}
}
fn count_rows_from_operation(operation: &Operation) -> u64 {
operation_fragments(operation)
.iter()
.map(|f| f.num_rows().unwrap_or(0) as u64)
.sum()
}
fn operation_fragments_mut(operation: &mut Operation) -> &mut Vec<Fragment> {
match operation {
Operation::Append { fragments } => fragments,
Operation::Overwrite { fragments, .. } => fragments,
_ => panic!("Unsupported operation type for getting mutable fragments"),
}
}
fn merge_transactions(mut transactions: Vec<Transaction>) -> Option<Transaction> {
let mut first = transactions.pop()?;
for txn in transactions {
let first_fragments = operation_fragments_mut(&mut first.operation);
let txn_fragments = operation_fragments(&txn.operation);
first_fragments.extend_from_slice(txn_fragments);
}
Some(first)
}
/// ExecutionPlan for inserting data into a native LanceDB table.
///
/// This plan executes inserts by:
/// 1. Each partition writes data independently using InsertBuilder::execute_uncommitted_stream
/// 2. The last partition to complete commits all transactions atomically
/// 3. Returns the count of inserted rows per partition
#[derive(Debug)]
pub struct InsertExec {
ds_wrapper: DatasetConsistencyWrapper,
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
properties: PlanProperties,
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
}
impl InsertExec {
pub fn new(
ds_wrapper: DatasetConsistencyWrapper,
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
) -> Self {
let schema = COUNT_SCHEMA.clone();
let num_partitions = input.output_partitioning().partition_count();
let properties = PlanProperties::new(
EquivalenceProperties::new(schema),
Partitioning::UnknownPartitioning(num_partitions),
EmissionType::Final,
Boundedness::Bounded,
);
Self {
ds_wrapper,
dataset,
input,
write_params,
properties,
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
}
}
}
impl DisplayAs for InsertExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "InsertExec: mode={:?}", self.write_params.mode)
}
DisplayFormatType::TreeRender => {
write!(f, "InsertExec")
}
}
}
}
impl ExecutionPlan for InsertExec {
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![false]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"InsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new(
self.ds_wrapper.clone(),
self.dataset.clone(),
children[0].clone(),
self.write_params.clone(),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let input_stream = self.input.execute(partition, context)?;
let dataset = self.dataset.clone();
let write_params = self.write_params.clone();
let partial_transactions = self.partial_transactions.clone();
let total_partitions = self.input.output_partitioning().partition_count();
let ds_wrapper = self.ds_wrapper.clone();
let stream = futures::stream::once(async move {
let transaction = InsertBuilder::new(dataset.clone())
.with_params(&write_params)
.execute_uncommitted_stream(input_stream)
.await?;
let num_rows = count_rows_from_operation(&transaction.operation);
let to_commit = {
// Don't hold the lock over an await point.
let mut txns = partial_transactions.lock().unwrap();
txns.push(transaction);
if txns.len() == total_partitions {
Some(std::mem::take(&mut *txns))
} else {
None
}
};
if let Some(transactions) = to_commit {
if let Some(merged_txn) = merge_transactions(transactions) {
let new_dataset = CommitBuilder::new(dataset.clone())
.execute(merged_txn)
.await?;
ds_wrapper.set_latest(new_dataset).await;
}
}
Ok(RecordBatch::try_new(
COUNT_SCHEMA.clone(),
vec![Arc::new(UInt64Array::from(vec![num_rows]))],
)?)
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
COUNT_SCHEMA.clone(),
stream,
)))
}
}
#[cfg(test)]
mod tests {
use std::vec;
use super::*;
use arrow_array::{record_batch, Int32Array, RecordBatchIterator};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use tempfile::tempdir;
use crate::connect;
#[tokio::test]
async fn test_insert_via_sql() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = db
.create_table("test_insert", Box::new(reader))
.execute()
.await
.unwrap();
// Verify initial count
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_insert", Arc::new(provider))
.unwrap();
ctx.sql("INSERT INTO test_insert VALUES (4), (5), (6)")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify final count
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 6);
}
#[tokio::test]
async fn test_insert_overwrite_via_sql() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table with 3 rows
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = db
.create_table("test_overwrite", Box::new(reader))
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_overwrite", Arc::new(provider))
.unwrap();
ctx.sql("INSERT OVERWRITE INTO test_overwrite VALUES (10), (20)")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have 2 rows (overwritten, not appended)
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 2);
}
#[tokio::test]
async fn test_insert_empty_batch() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
false,
)]));
let batches = vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap()];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let table = db
.create_table("test_empty", Box::new(reader))
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_empty", Arc::new(provider))
.unwrap();
let source_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
false,
)]));
// Empty batches
let source_reader = RecordBatchIterator::new(
std::iter::empty::<Result<RecordBatch, arrow_schema::ArrowError>>(),
source_schema,
);
let source_table = db
.create_table("empty_source", Box::new(source_reader))
.execute()
.await
.unwrap();
let source_provider =
crate::table::datafusion::BaseTableAdapter::try_new(source_table.base_table().clone())
.await
.unwrap();
ctx.register_table("empty_source", Arc::new(source_provider))
.unwrap();
// Execute INSERT with empty source
ctx.sql("INSERT INTO test_empty SELECT * FROM empty_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should still have 3 rows (nothing inserted)
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
}
#[tokio::test]
async fn test_insert_multiple_batches() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
let batches =
vec![
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))])
.unwrap(),
];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let table = db
.create_table("test_multi_batch", Box::new(reader))
.execute()
.await
.unwrap();
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_multi_batch", Arc::new(provider))
.unwrap();
// Memtable with multiple batches and multiple partitions
let source_table = MemTable::try_new(
schema.clone(),
vec![
// Partition 0
vec![
record_batch!(("id", Int32, [2, 3])).unwrap(),
record_batch!(("id", Int32, [4, 5])).unwrap(),
],
// Partition 1
vec![record_batch!(("id", Int32, [6, 7, 8])).unwrap()],
],
)
.unwrap();
ctx.register_table("multi_batch_source", Arc::new(source_table))
.unwrap();
ctx.sql("INSERT INTO test_multi_batch SELECT * FROM multi_batch_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have 1 + 2 + 2 + 3 = 8 rows
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 8);
}
}

View File

@@ -100,8 +100,7 @@ impl DatasetRef {
let should_checkout = match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
refs::Ref::Version(_, None) => true, // No specific version, always checkout
refs::Ref::VersionNumber(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {

View File

@@ -1,161 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the delete logic
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
// We access the dataset from the table. Since this is in the same module hierarchy (super),
// and 'dataset' is pub(crate), we can access it.
let mut dataset = table.dataset.get_mut().await?;
// Perform the actual delete on the Lance dataset
dataset.delete(predicate).await?;
// Return the result with the new version
Ok(DeleteResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::query::ExecutableQuery;
use futures::TryStreamExt;
#[tokio::test]
async fn test_delete_simple() {
let conn = connect("memory://").execute().await.unwrap();
// 1. Create a table with values 0 to 9
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap();
let table = conn
.create_table(
"test_delete",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// 2. Verify initial state
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// 3. Execute Delete (removes values > 5)
table.delete("i > 5").await.unwrap();
// 4. Verify results
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
// 5. Verify specific data consistency
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
// Ensure no value > 5 exists
for val in array.iter() {
assert!(val.unwrap() <= 5);
}
}
#[tokio::test]
async fn rows_removed_schema_same() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3, 4, 5]),
("name", Utf8, ["a", "b", "c", "d", "e"])
)
.unwrap();
let original_schema = batch.schema();
let table = conn
.create_table(
"test_delete_all",
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
)
.execute()
.await
.unwrap();
table.delete("true").await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 0);
let current_schema = table.schema().await.unwrap();
//check if the original schema is the same as current
assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();
// Create a table with 5 rows
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_delete_noop",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Capture the initial state (Rows = 5, Version = 1)
let initial_rows = table.count_rows(None).await.unwrap();
let initial_version = table.version().await.unwrap();
assert_eq!(initial_rows, 5);
table.delete("false").await.unwrap();
// Rows should still be 5
let current_rows = table.count_rows(None).await.unwrap();
assert_eq!(
current_rows, initial_rows,
"Data should not change when predicate is false"
);
// version check
let current_version = table.version().await.unwrap();
assert!(
current_version > initial_version,
"Table version must increment after delete operation"
);
}
}

View File

@@ -1,666 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Schema evolution operations for LanceDB tables.
//!
//! This module provides functionality to modify the schema of existing tables:
//! - [`add_columns`](execute_add_columns): Add new columns using SQL expressions
//! - [`alter_columns`](execute_alter_columns): Rename columns, change types, or modify nullability
//! - [`drop_columns`](execute_drop_columns): Remove columns from the table
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
/// The result of an add columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// The result of an alter columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AlterColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// The result of a drop columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DropColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the add columns logic.
///
/// Adds new columns to the table using the provided transforms.
pub(crate) async fn execute_add_columns(
table: &NativeTable,
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.add_columns(transforms, read_columns, None).await?;
Ok(AddColumnsResult {
version: dataset.version().version,
})
}
/// Internal implementation of the alter columns logic.
///
/// Alters existing columns in the table (rename, change type, or modify nullability).
pub(crate) async fn execute_alter_columns(
table: &NativeTable,
alterations: &[ColumnAlteration],
) -> Result<AlterColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.alter_columns(alterations).await?;
Ok(AlterColumnsResult {
version: dataset.version().version,
})
}
/// Internal implementation of the drop columns logic.
///
/// Removes columns from the table.
pub(crate) async fn execute_drop_columns(
table: &NativeTable,
columns: &[&str],
) -> Result<DropColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.drop_columns(columns).await?;
Ok(DropColumnsResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use arrow_array::{record_batch, Int32Array, RecordBatchIterator, StringArray};
use arrow_schema::DataType;
use futures::TryStreamExt;
use lance::dataset::ColumnAlteration;
use crate::connect;
use crate::query::{ExecutableQuery, QueryBase, Select};
use crate::table::NewColumnTransform;
// Add Columns Tests
#[tokio::test]
async fn test_add_columns_with_sql_expression() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_columns",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Add a computed column
let result = table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("doubled".into(), "id * 2".into())]),
None,
)
.await
.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify the new column exists with correct values
let batches = table
.query()
.select(Select::columns(&["id", "doubled"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let ids: Vec<i32> = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
let doubled: Vec<i32> = batch
.column(1)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
for (id, d) in ids.iter().zip(doubled.iter()) {
assert_eq!(*d, id * 2);
}
}
#[tokio::test]
async fn test_add_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("x", Int32, [10, 20, 30])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_multi_columns",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Add multiple columns at once
table
.add_columns(
NewColumnTransform::SqlExpressions(vec![
("y".into(), "x + 1".into()),
("z".into(), "x * x".into()),
]),
None,
)
.await
.unwrap();
// Verify schema has all columns
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 3);
assert!(schema.field_with_name("x").is_ok());
assert!(schema.field_with_name("y").is_ok());
assert!(schema.field_with_name("z").is_ok());
}
#[tokio::test]
async fn test_add_column_with_constant_expression() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_const_column",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Add a column with a constant value
table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("constant".into(), "42".into())]),
None,
)
.await
.unwrap();
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("constant").is_ok());
// Verify all values are 42
let batches = table
.query()
.select(Select::columns(&["constant"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let values = batch["constant"]
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.unwrap()
.values();
assert!(values.iter().all(|&v| v == 42));
}
// Alter Columns Tests
#[tokio::test]
async fn test_alter_column_rename() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("old_name", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_rename",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Rename the column
let result = table
.alter_columns(&[ColumnAlteration::new("old_name".into()).rename("new_name".into())])
.await
.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify rename
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("old_name").is_err());
assert!(schema.field_with_name("new_name").is_ok());
}
#[tokio::test]
async fn test_alter_column_set_nullable() {
use arrow_array::RecordBatch;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
let conn = connect("memory://").execute().await.unwrap();
// Create a schema with a non-nullable field
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let table = conn
.create_table(
"test_alter_nullable",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Initially non-nullable
let schema = table.schema().await.unwrap();
assert!(!schema.field_with_name("value").unwrap().is_nullable());
// Make it nullable
table
.alter_columns(&[ColumnAlteration::new("value".into()).set_nullable(true)])
.await
.unwrap();
// Verify it's now nullable
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("value").unwrap().is_nullable());
}
#[tokio::test]
async fn test_alter_column_cast_type() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_cast_type",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Cast Int32 to Int64 (a supported cast)
table
.alter_columns(&[ColumnAlteration::new("num".into()).cast_to(DataType::Int64)])
.await
.unwrap();
// Verify type changed
let schema = table.schema().await.unwrap();
assert_eq!(
schema.field_with_name("num").unwrap().data_type(),
&DataType::Int64
);
// Query the data and verify the returned type is correct
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let values = batch["num"]
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.unwrap()
.values();
assert_eq!(values.as_ref(), &[1i64, 2, 3]);
}
#[tokio::test]
async fn test_alter_column_invalid_cast_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_invalid_cast",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Casting Int32 to Float64 is not supported
let result = table
.alter_columns(&[ColumnAlteration::new("num".into()).cast_to(DataType::Float64)])
.await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("cast"),
"Expected error message to contain 'cast', got: {}",
err
);
}
#[tokio::test]
async fn test_alter_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_multi",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Alter multiple columns at once
table
.alter_columns(&[
ColumnAlteration::new("a".into()).rename("alpha".into()),
ColumnAlteration::new("b".into()).set_nullable(true),
])
.await
.unwrap();
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("alpha").is_ok());
assert!(schema.field_with_name("a").is_err());
assert!(schema.field_with_name("b").unwrap().is_nullable());
}
// Drop Columns Tests
#[tokio::test]
async fn test_drop_single_column() {
let conn = connect("memory://").execute().await.unwrap();
let batch =
record_batch!(("keep", Int32, [1, 2, 3]), ("remove", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_single",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Drop a column
let result = table.drop_columns(&["remove"]).await.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify column was dropped
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 1);
assert!(schema.field_with_name("keep").is_ok());
assert!(schema.field_with_name("remove").is_err());
}
#[tokio::test]
async fn test_drop_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("a", Int32, [1, 2]),
("b", Int32, [3, 4]),
("c", Int32, [5, 6]),
("d", Int32, [7, 8])
)
.unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_multi",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Drop multiple columns
table.drop_columns(&["b", "d"]).await.unwrap();
// Verify only a and c remain
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 2);
assert!(schema.field_with_name("a").is_ok());
assert!(schema.field_with_name("c").is_ok());
assert!(schema.field_with_name("b").is_err());
assert!(schema.field_with_name("d").is_err());
}
#[tokio::test]
async fn test_drop_column_preserves_data() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3]),
("name", Utf8, ["a", "b", "c"]),
("extra", Int32, [10, 20, 30])
)
.unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_preserves",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Drop the extra column
table.drop_columns(&["extra"]).await.unwrap();
// Verify remaining data is intact
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 3);
let ids: Vec<i32> = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
assert_eq!(ids, vec![1, 2, 3]);
let names: Vec<&str> = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
assert_eq!(names, vec!["a", "b", "c"]);
}
// Error Case Tests
#[tokio::test]
async fn test_drop_nonexistent_column_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_nonexistent",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Try to drop a column that doesn't exist
let result = table.drop_columns(&["nonexistent"]).await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("nonexistent"),
"Expected error message to contain column name 'nonexistent', got: {}",
err
);
}
#[tokio::test]
async fn test_alter_nonexistent_column_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_nonexistent",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Try to alter a column that doesn't exist
let result = table
.alter_columns(&[ColumnAlteration::new("nonexistent".into()).rename("new".into())])
.await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("nonexistent"),
"Expected error message to contain column name 'nonexistent', got: {}",
err
);
}
// Version Tracking Tests
#[tokio::test]
async fn test_schema_operations_increment_version() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_version_increment",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let v1 = table.version().await.unwrap();
// Add column increments version
let add_result = table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("c".into(), "a + b".into())]),
None,
)
.await
.unwrap();
assert!(add_result.version > v1);
let v2 = table.version().await.unwrap();
assert_eq!(add_result.version, v2);
// Alter column increments version
let alter_result = table
.alter_columns(&[ColumnAlteration::new("c".into()).rename("sum".into())])
.await
.unwrap();
assert!(alter_result.version > v2);
let v3 = table.version().await.unwrap();
assert_eq!(alter_result.version, v3);
// Drop column increments version
let drop_result = table.drop_columns(&["b"]).await.unwrap();
assert!(drop_result.version > v3);
let v4 = table.version().await.unwrap();
assert_eq!(drop_result.version, v4);
}
}

View File

@@ -1,441 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use lance::dataset::UpdateBuilder as LanceUpdateBuilder;
use serde::{Deserialize, Serialize};
use super::{BaseTable, NativeTable};
use crate::Error;
use crate::Result;
/// The result of an update operation
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct UpdateResult {
#[serde(default)]
pub rows_updated: u64,
/// The commit version associated with the operation.
#[serde(default)]
pub version: u64,
}
/// A builder for configuring a [`crate::table::Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn BaseTable>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
filter: None,
columns: Vec::new(),
}
}
/// Limits the update operation to rows matching the given filter
///
/// If a row does not match the filter then it will be left unchanged.
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
self.filter = Some(filter.into());
self
}
/// Specifies a column to update
///
/// This method may be called multiple times to update multiple columns
///
/// The `update_expr` should be an SQL expression explaining how to calculate
/// the new value for the column. The expression will be evaluated against the
/// previous row's value.
pub fn column(
mut self,
column_name: impl Into<String>,
update_expr: impl Into<String>,
) -> Self {
self.columns.push((column_name.into(), update_expr.into()));
self
}
/// Executes the update operation.
pub async fn execute(self) -> Result<UpdateResult> {
if self.columns.is_empty() {
Err(Error::InvalidInput {
message: "at least one column must be specified in an update operation".to_string(),
})
} else {
self.parent.clone().update(self).await
}
}
}
/// Internal implementation of the update logic
pub(crate) async fn execute_update(
table: &NativeTable,
update: UpdateBuilder,
) -> Result<UpdateResult> {
// 1. Snapshot the current dataset
let dataset = table.dataset.get().await?.clone();
// 2. Initialize the Lance Core builder
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
// 3. Apply the filter (WHERE clause)
if let Some(predicate) = update.filter {
builder = builder.update_where(&predicate)?;
}
// 4. Apply the columns (SET clause)
for (column, value) in update.columns {
builder = builder.set(column, &value)?;
}
// 5. Execute the operation (Write new files)
let operation = builder.build()?;
let res = operation.execute().await?;
// 6. Update the table's view of the latest version
table
.dataset
.set_latest(res.new_dataset.as_ref().clone())
.await;
Ok(UpdateResult {
rows_updated: res.rows_updated,
version: res.new_dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, Select};
use arrow_array::{
record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array,
Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use std::sync::Arc;
use std::time::Duration;
#[tokio::test]
async fn test_update_all_types() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
let mut update_op = table.update();
for (column, value) in updates {
update_op = update_op.column(column, value);
}
update_op.execute().await.unwrap();
let mut batches = table
.query()
.select(Select::columns(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
///Two helper functions
fn create_fixed_size_list<T: Array>(
values: T,
list_size: i32,
) -> Result<FixedSizeListArray, ArrowError> {
let list_type = DataType::FixedSizeList(
Arc::new(Field::new("item", values.data_type().clone(), true)),
list_size,
);
let data = ArrayDataBuilder::new(list_type)
.len(values.len() / list_size as usize)
.add_child_data(values.into_data())
.build()
.unwrap();
Ok(FixedSizeListArray::from(data))
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
#[tokio::test]
async fn test_update_with_predicate() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let batch = record_batch!(
("id", Int32, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
(
"name",
Utf8,
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
)
)
.unwrap();
let schema = batch.schema();
// need the iterator for create table
let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.update()
.only_if("id > 5")
.column("name", "'foo'")
.execute()
.await
.unwrap();
let mut batches = table
.query()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_via_expr() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let tbl = conn
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
tbl.update().column("i", "i+1").execute().await.unwrap();
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
}
}