Compare commits

..

16 Commits

Author SHA1 Message Date
Brendan Clement
5696df2791 test: assert branch handle read-consistency behavior 2026-06-03 13:02:27 -07:00
Brendan Clement
09518a3c1b test: skip async namespace branch test without lance 2026-06-03 11:55:01 -07:00
Brendan Clement
59824ab438 fix: address review comments on branch support 2026-06-03 11:18:14 -07:00
Brendan Clement
88c48a1bf0 docs: trim branch API comments 2026-06-03 10:07:57 -07:00
Brendan Clement
735a7ce6fe fix: validate branch inputs (empty names, negative versions) 2026-06-03 09:57:45 -07:00
Brendan Clement
1ee490d125 fix(python): skip server-side query pushdown on branch handles 2026-06-03 09:25:33 -07:00
Brendan Clement
08745dc1e1 test(python): guard pylance-dependent branch tests with importorskip 2026-06-02 23:54:40 -07:00
Brendan Clement
2660f96475 fix(python): branch-scope to_lance so branch handles don't read main 2026-06-02 23:20:47 -07:00
Brendan Clement
d96ae4b986 fix(python): keep open_table override signatures compatible with branch 2026-06-02 23:20:47 -07:00
Brendan Clement
38454969cd feat: support opening a branch directly via open_table 2026-06-02 22:40:13 -07:00
Brendan Clement
c13c3184cf address sync / namespace issue in python sdk 2026-06-02 22:04:31 -07:00
Brendan Clement
a7a7350eb3 feat(typescript): export Branches from the public API 2026-06-02 17:46:04 -07:00
Brendan Clement
c3c2887c02 refactor: use Self in branch method return types 2026-06-02 17:39:55 -07:00
Brendan Clement
2ca6d41f17 feat(typescript): add table branch support 2026-06-02 17:29:41 -07:00
Brendan Clement
341cb04c2f feat(python): add table branch support 2026-06-02 17:11:47 -07:00
Brendan Clement
0d4cb346f9 feat: add table branch support to the Rust core 2026-06-02 16:35:53 -07:00
57 changed files with 3454 additions and 4035 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.30.1-beta.2"
current_version = "0.30.1-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -21,14 +21,3 @@ updates:
update-types:
- minor
- patch
- package-ecosystem: pip
directory: /python
schedule:
interval: weekly
# Only update uv.lock, never widen version requirements in pyproject.toml.
versioning-strategy: lockfile-only
groups:
python-deps:
patterns:
- "*"

527
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=8.0.0-beta.4", default-features = false, "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=8.0.0-beta.4", default-features = false, "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=8.0.0-beta.4", default-features = false, "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=8.0.0-beta.4", "tag" = "v8.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -1,26 +0,0 @@
# Code review guidelines
Repo-specific guidance for automated PR reviews.
## Cross-SDK parity
LanceDB exposes the same core (`rust/lancedb`) through Python, TypeScript (`nodejs`),
and Java bindings. Behavioral drift between SDKs is a recurring problem, so watch for
parity gaps when reviewing — but only flag real ones:
* If the change adds or modifies user-facing API or behavior in the shared core
(`rust/lancedb`), check whether each binding that should expose it (`python`,
`nodejs`) does. A core change with no corresponding binding update is worth a note.
* If the change adds or modifies a public API in one SDK but not the other, open the
sibling SDK's corresponding module and state whether an equivalent exists. If not,
note it as a possible parity gap and suggest a follow-up issue.
* For bug fixes, first read the sibling SDK's analogous code path to check whether the
same bug exists there. Only raise parity if it actually does. Do not ask to "port" a
fix for a bug that only ever existed in one binding.
* Stay silent on internal-only refactors, tests, docs, and changes with no cross-SDK
surface.
* Parity expectations apply to the Python and TypeScript (`nodejs`) SDKs. Java currently
implements only the remote table, not the local/embedded backend, so it is expected to
be partial — do not flag Java for missing local-only functionality.
* Keep parity feedback to a short, clearly-labeled note (e.g. "Possible SDK parity
gap: …"). It is advisory, not a merge blocker.

View File

@@ -147,14 +147,6 @@ allow = [
"CDLA-Permissive-2.0",
]
confidence-threshold = 0.8
# Per-crate license exceptions: allow a license for a specific crate only,
# rather than globally via the `allow` list above.
exceptions = [
# CDDL-1.0 (copyleft) is pulled in only as a dev/profiling dependency via
# `inferno` -> `pprof` -> `lance-testing`; it is a test dependency that we
# do not distribute, so scope the allowance to `inferno` alone.
{ allow = ["CDDL-1.0"], crate = "inferno" },
]
# Crates whose license cannot be determined from Cargo metadata but whose
# license we've manually confirmed from upstream. Keep this list minimal.
[[licenses.clarify]]

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.30.1-beta.2</version>
<version>0.30.1-beta.0</version>
</dependency>
```

View File

@@ -0,0 +1,43 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / BranchContents
# Class: BranchContents
## Constructors
### new BranchContents()
```ts
new BranchContents(): BranchContents
```
#### Returns
[`BranchContents`](BranchContents.md)
## Properties
### manifestSize
```ts
manifestSize: number;
```
***
### parentBranch?
```ts
optional parentBranch: string;
```
***
### parentVersion
```ts
parentVersion: number;
```

View File

@@ -0,0 +1,90 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / Branches
# Class: Branches
Branch manager for a [Table](Table.md).
Unlike tags, `create` and `checkout` return a new [Table](Table.md) handle scoped
to the branch; writes on it do not affect `main`.
## Methods
### checkout()
```ts
checkout(name): Promise<Table>
```
Check out an existing branch and return a handle scoped to it.
#### Parameters
* **name**: `string`
#### Returns
`Promise`&lt;[`Table`](Table.md)&gt;
***
### create()
```ts
create(
name,
fromRef?,
fromVersion?): Promise<Table>
```
Create a branch and return a handle scoped to it.
#### Parameters
* **name**: `string`
Name of the new branch.
* **fromRef?**: `string`
Source branch to fork from. Defaults to `main`.
* **fromVersion?**: `number`
A specific version on `fromRef`. Defaults to latest.
#### Returns
`Promise`&lt;[`Table`](Table.md)&gt;
***
### delete()
```ts
delete(name): Promise<void>
```
Delete a branch.
#### Parameters
* **name**: `string`
#### Returns
`Promise`&lt;`void`&gt;
***
### list()
```ts
list(): Promise<Record<string, BranchContents>>
```
List all branches, mapping name to branch metadata.
#### Returns
`Promise`&lt;`Record`&lt;`string`, [`BranchContents`](BranchContents.md)&gt;&gt;

View File

@@ -110,6 +110,23 @@ containing the new version number of the table after altering the columns.
***
### branches()
```ts
abstract branches(): Promise<Branches>
```
Get the branch manager for this table.
Branches are isolated, writable lines of history forked from another
branch (or version). Writes on a branch do not affect `main`.
#### Returns
`Promise`&lt;[`Branches`](Branches.md)&gt;
***
### checkout()
```ts

View File

@@ -19,6 +19,8 @@
- [BooleanQuery](classes/BooleanQuery.md)
- [BoostQuery](classes/BoostQuery.md)
- [BranchContents](classes/BranchContents.md)
- [Branches](classes/Branches.md)
- [Connection](classes/Connection.md)
- [HeaderProvider](classes/HeaderProvider.md)
- [Index](classes/Index.md)

View File

@@ -8,6 +8,18 @@
## Properties
### branch?
```ts
optional branch: string;
```
Open the table scoped to this branch instead of the default branch.
Reads and writes on the returned table operate in the branch's context.
***
### ~~indexCacheSize?~~
```ts

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.30.1-beta.2</version>
<version>0.30.1-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.30.1-beta.2</version>
<version>0.30.1-beta.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>8.0.0-beta.4</lance-core.version>
<lance-core.version>7.2.0-beta.1</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.30.1-beta.2"
version = "0.30.1-beta.0"
publish = false
license.workspace = true
description.workspace = true

View File

@@ -85,6 +85,64 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
await expect(table.countRows()).resolves.toBe(3);
});
it("should support branches", async () => {
await table.add([{ id: 1 }]);
expect(await table.countRows()).toBe(1);
// fork an isolated, writable branch from main
const branch = await (await table.branches()).create("exp");
expect(await branch.countRows()).toBe(1);
await branch.add([{ id: 2 }]);
expect(await branch.countRows()).toBe(2);
// main is untouched by branch writes
expect(await table.countRows()).toBe(1);
// listed, with main (null) as the parent
const list = await (await table.branches()).list();
expect(Object.keys(list)).toContain("exp");
expect(list["exp"].parentBranch).toBeNull();
// fromRef="main" is equivalent to the default
await (await table.branches()).create("exp2", "main");
const list2 = await (await table.branches()).list();
expect(list2["exp2"].parentBranch).toBeNull();
// checkout returns a handle scoped to the branch's latest
const checkedOut = await (await table.branches()).checkout("exp");
expect(await checkedOut.countRows()).toBe(2);
// delete removes it
await (await table.branches()).delete("exp");
await (await table.branches()).delete("exp2");
const after = await (await table.branches()).list();
expect(Object.keys(after)).not.toContain("exp");
});
it("should open a branch via open_table", async () => {
const db = await connect(tmpDir.name);
await table.add([{ id: 1 }]);
const branch = await (await table.branches()).create("exp");
await branch.add([{ id: 2 }]);
// open_table(..., { branch }) returns a handle scoped to the branch
const opened = await db.openTable("some_table", undefined, {
branch: "exp",
});
expect(await opened.countRows()).toBe(2);
// opening without branch still tracks main
expect(await (await db.openTable("some_table")).countRows()).toBe(1);
});
it("rejects invalid branch inputs", async () => {
const branches = await table.branches();
await expect(branches.create("")).rejects.toThrow("non-empty");
await expect(branches.checkout("")).rejects.toThrow("non-empty");
await expect(branches.delete("")).rejects.toThrow("non-empty");
await expect(branches.create("bad", "main", -1)).rejects.toThrow(
"non-negative",
);
});
it("should show table stats", async () => {
await table.add([{ id: 1 }, { id: 2 }]);
await table.add([{ id: 1 }]);

View File

@@ -84,6 +84,12 @@ export interface CreateTableOptions {
}
export interface OpenTableOptions {
/**
* Open the table scoped to this branch instead of the default branch.
*
* Reads and writes on the returned table operate in the branch's context.
*/
branch?: string;
/**
* Configuration for object storage.
*
@@ -483,7 +489,11 @@ export class LocalConnection extends Connection {
options?.indexCacheSize,
);
return new LocalTable(innerTable);
const table = new LocalTable(innerTable);
if (options?.branch != null) {
return (await table.branches()).checkout(options.branch);
}
return table;
}
async cloneTable(

View File

@@ -38,6 +38,7 @@ export {
FragmentSummaryStats,
Tags,
TagContents,
BranchContents,
MergeResult,
AddResult,
AddColumnsResult,
@@ -111,6 +112,7 @@ export {
export {
Table,
Branches,
AddDataOptions,
UpdateOptions,
OptimizeOptions,

View File

@@ -25,10 +25,12 @@ import {
AddColumnsSql,
AddResult,
AlterColumnsResult,
BranchContents,
DeleteResult,
DropColumnsResult,
IndexConfig,
IndexStatistics,
Branches as NativeBranches,
OptimizeStats,
TableStatistics,
Tags,
@@ -653,6 +655,14 @@ export abstract class Table {
*/
abstract tags(): Promise<Tags>;
/**
* Get the branch manager for this table.
*
* Branches are isolated, writable lines of history forked from another
* branch (or version). Writes on a branch do not affect `main`.
*/
abstract branches(): Promise<Branches>;
/**
* Restore the table to the currently checked out version
*
@@ -1108,6 +1118,10 @@ export class LocalTable extends Table {
return await this.inner.tags();
}
async branches(): Promise<Branches> {
return new Branches(await this.inner.branches());
}
async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> {
let cleanupOlderThanMs;
if (
@@ -1238,3 +1252,51 @@ export interface FieldMetadataUpdate {
/** If true, replace the field's entire metadata map instead of merging. */
replace?: boolean;
}
/**
* Branch manager for a {@link Table}.
*
* Unlike tags, `create` and `checkout` return a new {@link Table} handle scoped
* to the branch; writes on it do not affect `main`.
*/
export class Branches {
#inner: NativeBranches;
/**
* Construct a Branches manager. Internal use only.
* @hidden
*/
constructor(inner: NativeBranches) {
this.#inner = inner;
}
/** List all branches, mapping name to branch metadata. */
async list(): Promise<Record<string, BranchContents>> {
return await this.#inner.list();
}
/**
* Create a branch and return a handle scoped to it.
*
* @param name Name of the new branch.
* @param fromRef Source branch to fork from. Defaults to `main`.
* @param fromVersion A specific version on `fromRef`. Defaults to latest.
*/
async create(
name: string,
fromRef?: string,
fromVersion?: number,
): Promise<Table> {
return new LocalTable(await this.#inner.create(name, fromRef, fromVersion));
}
/** Check out an existing branch and return a handle scoped to it. */
async checkout(name: string): Promise<Table> {
return new LocalTable(await this.#inner.checkout(name));
}
/** Delete a branch. */
async delete(name: string): Promise<void> {
return await this.#inner.delete(name);
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.30.1-beta.2",
"version": "0.30.1-beta.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -7,7 +7,7 @@ use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};
use lancedb::table::{
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration,
FieldMetadataUpdate as LanceFieldMetadataUpdate, NewColumnTransform, OptimizeAction,
OptimizeOptions, Table as LanceDbTable,
OptimizeOptions, Ref, Table as LanceDbTable,
};
use napi::bindgen_prelude::*;
use napi::threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode};
@@ -478,6 +478,13 @@ impl Table {
})
}
#[napi(catch_unwind)]
pub async fn branches(&self) -> napi::Result<Branches> {
Ok(Branches {
inner: self.inner_ref()?.clone(),
})
}
#[napi(catch_unwind)]
pub async fn optimize(
&self,
@@ -1060,6 +1067,13 @@ pub struct TagContents {
pub manifest_size: i64,
}
#[napi]
pub struct BranchContents {
pub parent_branch: Option<String>,
pub parent_version: i64,
pub manifest_size: i64,
}
#[napi]
pub struct Tags {
inner: LanceDbTable,
@@ -1128,3 +1142,65 @@ impl Tags {
.default_error()
}
}
#[napi]
pub struct Branches {
inner: LanceDbTable,
}
#[napi]
impl Branches {
#[napi]
pub async fn list(&self) -> napi::Result<HashMap<String, BranchContents>> {
let branches = self.inner.list_branches().await.default_error()?;
let result = branches
.into_iter()
.map(|(k, v)| {
(
k,
BranchContents {
parent_branch: v.parent_branch,
parent_version: v.parent_version as i64,
manifest_size: v.manifest_size as i64,
},
)
})
.collect();
Ok(result)
}
#[napi]
pub async fn create(
&self,
name: String,
from_ref: Option<String>,
from_version: Option<i64>,
) -> napi::Result<Table> {
let from_ref = from_ref.filter(|b| b != "main");
let from_version = from_version
.map(|v| {
u64::try_from(v).map_err(|_| {
napi::Error::from_reason("from_version must be a non-negative integer")
})
})
.transpose()?;
let from = Ref::Version(from_ref, from_version);
let table = self
.inner
.create_branch(&name, from)
.await
.default_error()?;
Ok(Table::new(table))
}
#[napi]
pub async fn checkout(&self, name: String) -> napi::Result<Table> {
let table = self.inner.checkout_branch(&name).await.default_error()?;
Ok(Table::new(table))
}
#[napi]
pub async fn delete(&self, name: String) -> napi::Result<()> {
self.inner.delete_branch(&name).await.default_error()
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.33.1-beta.2"
current_version = "0.33.1-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.33.1-beta.2"
version = "0.33.1-beta.0"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"

View File

@@ -226,6 +226,9 @@ class Table:
async def close_lsm_writers(self) -> None: ...
@property
def tags(self) -> Tags: ...
@property
def branches(self) -> Branches: ...
def current_branch(self) -> Optional[str]: ...
def query(self) -> Query: ...
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
@@ -238,6 +241,17 @@ class Tags:
async def delete(self, tag: str): ...
async def update(self, tag: str, version: int): ...
class Branches:
async def list(self) -> Dict[str, Any]: ...
async def create(
self,
name: str,
from_ref: Optional[str] = None,
from_version: Optional[int] = None,
) -> Table: ...
async def checkout(self, name: str) -> Table: ...
async def delete(self, name: str) -> None: ...
class IndexConfig:
name: str
index_type: str

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import asyncio
import concurrent.futures
import os
import threading
import warnings
@@ -38,24 +37,6 @@ class BackgroundEventLoop:
LOOP = BackgroundEventLoop()
def _new_embedding_executor() -> concurrent.futures.ThreadPoolExecutor:
return concurrent.futures.ThreadPoolExecutor(thread_name_prefix="lancedb-embedding")
# Embedding functions can block for a long time -- a heavy local model or an
# HTTP request to a remote embeddings API. Running them on asyncio's default
# executor lets them starve the unrelated blocking I/O that shares that pool,
# so they get a dedicated one. See
# https://github.com/lancedb/lancedb/issues/3310.
_EMBEDDING_EXECUTOR = _new_embedding_executor()
def embedding_executor() -> concurrent.futures.ThreadPoolExecutor:
"""Return the executor dedicated to running blocking embedding calls."""
return _EMBEDDING_EXECUTOR
_FORK_WARNED = False
@@ -66,12 +47,6 @@ def _reset_after_fork():
# the new state. The Rust-side tokio runtime is reset analogously by a
# pthread_atfork hook installed in the _lancedb extension.
LOOP._start()
# The embedding executor's worker threads are dead in the child as well.
# Replace it with a fresh pool (threads are spawned lazily, so this is
# cheap); we don't shut down the old one, since joining its dead workers
# could hang.
global _EMBEDDING_EXECUTOR
_EMBEDDING_EXECUTOR = _new_embedding_executor()
global _FORK_WARNED
if not _FORK_WARNED:
_FORK_WARNED = True

View File

@@ -416,6 +416,7 @@ class DBConnection(EnforceOverrides):
namespace_path: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
branch: Optional[str] = None,
) -> Table:
"""Open a Lance Table in the database.
@@ -444,6 +445,9 @@ class DBConnection(EnforceOverrides):
connection will be inherited by the table, but can be overridden here.
See available options at
<https://docs.lancedb.com/storage/>
branch: str, optional
If provided, open a handle scoped to this branch instead of the
default branch. Reads and writes operate in the branch's context.
Returns
-------
@@ -958,6 +962,7 @@ class LanceDBConnection(DBConnection):
namespace_path: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
branch: Optional[str] = None,
) -> LanceTable:
"""Open a table in the database.
@@ -968,6 +973,9 @@ class LanceDBConnection(DBConnection):
namespace_path: List[str], optional
The namespace to open the table from. When non-empty, the
table is resolved through the directory namespace client.
branch: str, optional
If provided, open a handle scoped to this branch instead of the
default branch. Reads and writes operate in the branch's context.
Returns
-------
@@ -987,20 +995,24 @@ class LanceDBConnection(DBConnection):
)
if namespace_path:
return self._namespace_conn().open_table(
tbl = self._namespace_conn().open_table(
name,
namespace_path=namespace_path,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
else:
tbl = LanceTable.open(
self,
name,
namespace_path=namespace_path,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
return LanceTable.open(
self,
name,
namespace_path=namespace_path,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
if branch is not None:
return tbl.branches.checkout(branch)
return tbl
def clone_table(
self,
@@ -1641,6 +1653,7 @@ class AsyncConnection(object):
location: Optional[str] = None,
namespace_client: Optional[Any] = None,
managed_versioning: Optional[bool] = None,
branch: Optional[str] = None,
) -> AsyncTable:
"""Open a Lance Table in the database.
@@ -1676,6 +1689,9 @@ class AsyncConnection(object):
managed_versioning: bool, optional
Whether managed versioning is enabled for this table. If provided,
avoids a redundant describe_table call when namespace_client is set.
branch: str, optional
If provided, open a handle scoped to this branch instead of the
default branch. Reads and writes operate in the branch's context.
Returns
-------
@@ -1692,7 +1708,10 @@ class AsyncConnection(object):
namespace_client=namespace_client,
managed_versioning=managed_versioning,
)
return AsyncTable(table)
tbl = AsyncTable(table)
if branch is not None:
return await tbl.branches.checkout(branch)
return tbl
async def clone_table(
self,

View File

@@ -544,6 +544,7 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
branch: Optional[str] = None,
) -> Table:
if namespace_path is None:
namespace_path = []
@@ -562,7 +563,7 @@ class LanceNamespaceDBConnection(DBConnection):
raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}")
raise
return LanceTable(
tbl = LanceTable(
self,
name,
namespace_path=namespace_path,
@@ -570,6 +571,9 @@ class LanceNamespaceDBConnection(DBConnection):
pushdown_operations=self._namespace_client_pushdown_operations,
_async=async_table,
)
if branch is not None:
return tbl.branches.checkout(branch)
return tbl
@override
def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
@@ -974,12 +978,13 @@ class AsyncLanceNamespaceDBConnection:
namespace_path: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
branch: Optional[str] = None,
) -> AsyncTable:
"""Open an existing table from the namespace."""
if namespace_path is None:
namespace_path = []
try:
return await self._inner.open_table(
tbl = await self._inner.open_table(
name,
namespace_path=namespace_path,
storage_options=storage_options,
@@ -990,6 +995,9 @@ class AsyncLanceNamespaceDBConnection:
table_id = namespace_path + [name]
raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}")
raise
if branch is not None:
return await tbl.branches.checkout(branch)
return tbl
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
"""Drop a table from the namespace."""

View File

@@ -41,14 +41,6 @@ from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import flatten_columns
BlobMode = Literal["lazy", "bytes", "descriptions"]
_BLOB_MODE_TO_HANDLING = {
"lazy": "blobs_descriptions",
"bytes": "all_binary",
"descriptions": "blobs_descriptions",
}
if TYPE_CHECKING:
import sys
@@ -63,7 +55,7 @@ if TYPE_CHECKING:
from ._lancedb import VectorQuery as LanceVectorQuery
from .common import VEC
from .pydantic import LanceModel
from .table import AsyncTable, Table
from .table import Table
if sys.version_info >= (3, 11):
from typing import Self
@@ -73,179 +65,6 @@ if TYPE_CHECKING:
T = TypeVar("T", bound="LanceModel")
def _validate_blob_mode(blob_mode: BlobMode) -> None:
if blob_mode not in _BLOB_MODE_TO_HANDLING:
modes = ", ".join(repr(mode) for mode in _BLOB_MODE_TO_HANDLING)
raise ValueError(f"blob_mode must be one of {modes}, got {blob_mode!r}")
def _field_is_blob(field: pa.Field) -> bool:
metadata = field.metadata or {}
return metadata.get(b"lance-encoding:blob") == b"true" or (
metadata.get("lance-encoding:blob") == "true"
)
def _schema_has_blob_field(schema: pa.Schema) -> bool:
return any(_field_is_blob(field) for field in schema)
def _blob_mode_requires_native_pandas(blob_mode: BlobMode, schema: pa.Schema) -> bool:
return blob_mode in _BLOB_MODE_TO_HANDLING and _schema_has_blob_field(schema)
def _unsupported_blob_pandas_error(reason: str) -> RuntimeError:
return RuntimeError(
"blob columns require Lance native scanner conversion for query "
f"to_pandas(), but {reason}. Use a plain scan query or remove blob "
"columns from the projection."
)
def _query_is_plain_scan(query: Query) -> bool:
return (
query.vector is None
and query.full_text_query is None
and not query.postfilter
and not query.order_by
)
def _filter_to_sql(filter: Optional[Union[str, Expr]]) -> Optional[str]:
if filter is None:
return None
if isinstance(filter, Expr):
return filter.to_sql()
return filter
def _projection_to_scanner_kwargs(
columns: Optional[
Union[
List[str], List[Tuple[str, Union[str, Expr]]], Dict[str, Union[str, Expr]]
]
],
) -> Dict[str, Any]:
if columns is None:
return {}
if isinstance(columns, list):
if all(isinstance(column, str) for column in columns):
return {"columns": columns}
if all(isinstance(column, tuple) and len(column) == 2 for column in columns):
return {
"columns": {
name: expr.to_sql() if isinstance(expr, Expr) else expr
for name, expr in columns
}
}
# Let Lance raise the detailed projection validation error.
return {"columns": columns}
projection = {}
for name, expr in columns.items():
if isinstance(expr, Expr):
expr = expr.to_sql()
projection[name] = expr
return {"columns": projection}
def _scanner_kwargs_for_query(
query: Query, blob_mode: BlobMode, dataset: Optional[Any] = None
) -> Dict[str, Any]:
fragments = _scanner_fragments_for_query(query, dataset)
kwargs = {
**_projection_to_scanner_kwargs(query.columns),
"filter": _filter_to_sql(query.filter),
"limit": query.limit,
"offset": query.offset,
"with_row_id": query.with_row_id,
"with_row_address": query.with_row_address,
"fast_search": query.fast_search,
"blob_handling": _BLOB_MODE_TO_HANDLING[blob_mode],
"fragments": fragments,
}
return {key: value for key, value in kwargs.items() if value is not None}
def _scanner_fragments_for_query(query: Query, dataset: Optional[Any]) -> Optional[Any]:
if query.fragments is not None and query.fragment_ids is not None:
raise ValueError("fragments and fragment_ids cannot both be set")
if query.fragments is not None:
return query.fragments
if query.fragment_ids is None:
return None
if dataset is None:
raise ValueError("fragment_ids require a Lance dataset")
requested = set(query.fragment_ids)
fragments = [
fragment
for fragment in dataset.get_fragments()
if fragment.fragment_id in requested
]
found = {fragment.fragment_id for fragment in fragments}
missing = requested - found
if missing:
missing_ids = ", ".join(str(fragment_id) for fragment_id in sorted(missing))
raise ValueError(f"fragment_ids not found in dataset: {missing_ids}")
return fragments
def _ensure_lazy_blob_frame(
df: "pd.DataFrame", schema: pa.Schema, blob_mode: BlobMode
) -> "pd.DataFrame":
if blob_mode != "lazy" or not _schema_has_blob_field(schema) or len(df) == 0:
return df
for field in schema:
if not _field_is_blob(field) or field.name not in df.columns:
continue
value = df[field.name].iloc[0]
if value is not None and not hasattr(value, "readall"):
raise _unsupported_blob_pandas_error(
"the Lance scanner did not return lazy blob files"
)
return df
def _scanner_to_table(scanner: Any) -> pa.Table:
if hasattr(scanner, "to_pyarrow"):
reader = scanner.to_pyarrow()
return reader.read_all()
if hasattr(scanner, "to_table"):
return scanner.to_table()
reader = scanner.to_reader()
return reader.read_all()
def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataFrame":
schema = getattr(scanner, "projected_schema", None)
if schema is None:
schema = getattr(scanner, "schema", None)
if schema is None:
schema = getattr(scanner, "dataset_schema", None)
if callable(schema):
schema = schema()
if hasattr(scanner, "to_pandas"):
try:
df = scanner.to_pandas(blob_mode=blob_mode, **kwargs)
except TypeError as err:
message = str(err)
if "blob_mode" not in message and "unexpected keyword" not in message:
raise
df = scanner.to_pandas(**kwargs)
if schema is not None:
return _ensure_lazy_blob_frame(df, schema, blob_mode)
return df
tbl = _scanner_to_table(scanner)
if blob_mode == "lazy" and _schema_has_blob_field(tbl.schema):
raise _unsupported_blob_pandas_error(
"the Lance scanner does not expose to_pandas"
)
return tbl.to_pandas(**kwargs)
# Pydantic validation function for vector queries
def ensure_vector_query(
val: Any,
@@ -680,13 +499,6 @@ class Query(pydantic.BaseModel):
# if true, include the row id in the results
with_row_id: Optional[bool] = None
# if true, include the row address in the results
with_row_address: Optional[bool] = None
# Lance fragments or fragment ids to scan on scanner-backed plain queries
fragments: Optional[Any] = None
fragment_ids: Optional[List[int]] = None
# offset to start fetching results from
offset: Optional[int] = None
@@ -879,9 +691,6 @@ class LanceQueryBuilder(ABC):
self._where = None
self._postfilter = None
self._with_row_id = None
self._with_row_address = None
self._fragments = None
self._fragment_ids = None
self._vector = None
self._text = None
self._ef = None
@@ -909,7 +718,6 @@ class LanceQueryBuilder(ABC):
self,
flatten: Optional[Union[int, bool]] = None,
*,
blob_mode: BlobMode = "lazy",
timeout: Optional[timedelta] = None,
**kwargs,
) -> "pd.DataFrame":
@@ -929,41 +737,11 @@ class LanceQueryBuilder(ABC):
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
blob_mode: str, default "lazy"
Controls how blob columns are returned for plain scan queries.
Vector, FTS, hybrid, and other non-native query shapes keep the
existing Arrow conversion path and only support blob descriptions.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
_validate_blob_mode(blob_mode)
output_schema = getattr(self, "output_schema", None)
if output_schema is not None:
schema = output_schema()
if _blob_mode_requires_native_pandas(blob_mode, schema):
native_error = None
if (flatten is None or blob_mode == "descriptions") and timeout is None:
try:
df = self._plain_scan_to_pandas(
blob_mode, flatten=flatten, **kwargs
)
if df is not None:
return df
except Exception as err:
native_error = err
reason = (
"this query shape cannot use Lance native pandas conversion"
if native_error is None
else str(native_error)
)
raise _unsupported_blob_pandas_error(reason) from native_error
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
raise _unsupported_blob_pandas_error(
"this query shape cannot use Lance native pandas conversion"
)
return tbl.to_pandas(**kwargs)
@abstractmethod
@@ -1169,32 +947,6 @@ class LanceQueryBuilder(ABC):
self._with_row_id = with_row_id
return self
def with_row_address(self, with_row_address: bool = True) -> Self:
"""Set whether to return row addresses.
Parameters
----------
with_row_address: bool, default True
If True, return the _rowaddr column in the results.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._with_row_address = with_row_address
return self
def with_fragments(self, fragments: Any) -> Self:
"""Set the Lance fragments to scan for plain scanner-backed queries."""
self._fragments = fragments
return self
def fragment_ids(self, fragment_ids: List[int]) -> Self:
"""Set the Lance fragment ids to scan for plain scanner-backed queries."""
self._fragment_ids = fragment_ids
return self
def explain_plan(self, verbose: Optional[bool] = False) -> str:
"""Return the execution plan for this query.
@@ -1334,25 +1086,6 @@ class LanceQueryBuilder(ABC):
"""
raise NotImplementedError
def _plain_scan_to_pandas(
self,
blob_mode: BlobMode,
flatten: Optional[Union[int, bool]] = None,
**kwargs,
) -> Optional["pd.DataFrame"]:
query = self.to_query_object()
if not _query_is_plain_scan(query):
return None
dataset = self._table.to_lance()
scanner = dataset.scanner(
**_scanner_kwargs_for_query(query, blob_mode, dataset)
)
if flatten is not None:
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
return tbl.to_pandas(**kwargs)
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
@abstractmethod
def to_query_object(self) -> Query:
"""Return a serializable representation of the query
@@ -1624,9 +1357,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
with_row_address=self._with_row_address,
fragments=self._fragments,
fragment_ids=self._fragment_ids,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
@@ -1829,9 +1559,6 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
limit=self._limit,
postfilter=self._postfilter,
with_row_id=self._with_row_id,
with_row_address=self._with_row_address,
fragments=self._fragments,
fragment_ids=self._fragment_ids,
full_text_query=FullTextSearchQuery(
query=self._query, columns=self._fts_columns
),
@@ -1902,9 +1629,6 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
filter=self._where,
limit=self._limit,
with_row_id=self._with_row_id,
with_row_address=self._with_row_address,
fragments=self._fragments,
fragment_ids=self._fragment_ids,
offset=self._offset,
order_by=self._order_by,
)
@@ -2483,11 +2207,7 @@ class AsyncQueryBase(object):
Base class for all async queries (take, scan, vector, fts, hybrid)
"""
def __init__(
self,
inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery],
table: Optional["AsyncTable"] = None,
):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery]):
"""
Construct an AsyncQueryBase
@@ -2495,10 +2215,6 @@ class AsyncQueryBase(object):
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
"""
self._inner = inner
self._table = table
self._with_row_address = None
self._fragments = None
self._fragment_ids = None
def to_query_object(self) -> Query:
"""
@@ -2507,11 +2223,7 @@ class AsyncQueryBase(object):
This is currently experimental but can be useful as the query object is pure
python and more easily serializable.
"""
query = Query.from_inner(self._inner.to_query_request())
query.with_row_address = self._with_row_address
query.fragments = self._fragments
query.fragment_ids = self._fragment_ids
return query
return Query.from_inner(self._inner.to_query_request())
def select(self, columns: Union[List[str], dict[str, str]]) -> Self:
"""
@@ -2568,27 +2280,6 @@ class AsyncQueryBase(object):
self._inner.with_row_id()
return self
def with_row_address(self, with_row_address: bool = True) -> Self:
"""
Include the _rowaddr column in scanner-backed plain query results.
"""
self._with_row_address = with_row_address
return self
def with_fragments(self, fragments: Any) -> Self:
"""
Restrict scanner-backed plain query results to the given Lance fragments.
"""
self._fragments = fragments
return self
def fragment_ids(self, fragment_ids: List[int]) -> Self:
"""
Restrict scanner-backed plain query results to the given Lance fragment ids.
"""
self._fragment_ids = fragment_ids
return self
async def to_batches(
self,
*,
@@ -2666,8 +2357,6 @@ class AsyncQueryBase(object):
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
*,
blob_mode: BlobMode = "lazy",
**kwargs,
) -> "pd.DataFrame":
"""
@@ -2701,63 +2390,13 @@ class AsyncQueryBase(object):
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
blob_mode: str, default "lazy"
Controls how blob columns are returned for plain scan queries.
Vector, FTS, hybrid, and other non-native query shapes keep the
existing Arrow conversion path and only support blob descriptions.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
_validate_blob_mode(blob_mode)
if hasattr(self._inner, "output_schema"):
schema = await self.output_schema()
if _blob_mode_requires_native_pandas(blob_mode, schema):
native_error = None
if (flatten is None or blob_mode == "descriptions") and timeout is None:
try:
df = await self._plain_scan_to_pandas(
blob_mode, flatten=flatten, **kwargs
)
if df is not None:
return df
except Exception as err:
native_error = err
reason = (
"this query shape cannot use Lance native pandas conversion"
if native_error is None
else str(native_error)
)
raise _unsupported_blob_pandas_error(reason) from native_error
tbl = flatten_columns(await self.to_arrow(timeout=timeout), flatten)
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
raise _unsupported_blob_pandas_error(
"this query shape cannot use Lance native pandas conversion"
)
return tbl.to_pandas(**kwargs)
async def _plain_scan_to_pandas(
self,
blob_mode: BlobMode,
flatten: Optional[Union[int, bool]] = None,
**kwargs,
) -> Optional["pd.DataFrame"]:
if self._table is None:
return None
query = self.to_query_object()
if not _query_is_plain_scan(query):
return None
dataset = await self._table._to_lance()
scanner = dataset.scanner(
**_scanner_kwargs_for_query(query, blob_mode, dataset)
)
if flatten is not None:
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
return tbl.to_pandas(**kwargs)
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
return (
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
).to_pandas(**kwargs)
async def to_polars(
self,
@@ -2864,18 +2503,14 @@ class AsyncStandardQuery(AsyncQueryBase):
Base class for "standard" async queries (all but take currently)
"""
def __init__(
self,
inner: Union[LanceQuery, LanceVectorQuery],
table: Optional["AsyncTable"] = None,
):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
"""
Construct an AsyncStandardQuery
This method is not intended to be called directly. Instead, use the
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
"""
super().__init__(inner, table)
super().__init__(inner)
def where(self, predicate: Union[str, Expr]) -> Self:
"""
@@ -2981,14 +2616,14 @@ class AsyncStandardQuery(AsyncQueryBase):
class AsyncQuery(AsyncStandardQuery):
def __init__(self, inner: LanceQuery, table: Optional["AsyncTable"] = None):
def __init__(self, inner: LanceQuery):
"""
Construct an AsyncQuery
This method is not intended to be called directly. Instead, use the
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
"""
super().__init__(inner, table)
super().__init__(inner)
self._inner = inner
@classmethod
@@ -3072,11 +2707,10 @@ class AsyncQuery(AsyncStandardQuery):
new_self = self._inner.nearest_to(query_vectors[0])
for v in query_vectors[1:]:
new_self.add_query_vector(v)
return AsyncVectorQuery(new_self, self._table)
return AsyncVectorQuery(new_self)
else:
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
self._table,
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
def nearest_to_text(
@@ -3109,18 +2743,17 @@ class AsyncQuery(AsyncStandardQuery):
if isinstance(query, str):
return AsyncFTSQuery(
self._inner.nearest_to_text({"query": query, "columns": columns}),
self._table,
self._inner.nearest_to_text({"query": query, "columns": columns})
)
# FullTextQuery object
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}), self._table)
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
class AsyncFTSQuery(AsyncStandardQuery):
"""A query for full text search for LanceDB."""
def __init__(self, inner: LanceFTSQuery, table: Optional["AsyncTable"] = None):
super().__init__(inner, table)
def __init__(self, inner: LanceFTSQuery):
super().__init__(inner)
self._inner = inner
self._reranker = None
@@ -3202,11 +2835,10 @@ class AsyncFTSQuery(AsyncStandardQuery):
new_self = self._inner.nearest_to(query_vectors[0])
for v in query_vectors[1:]:
new_self.add_query_vector(v)
return AsyncHybridQuery(new_self, self._table)
return AsyncHybridQuery(new_self)
else:
return AsyncHybridQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
self._table,
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
async def to_batches(
@@ -3397,7 +3029,7 @@ class AsyncVectorQueryBase:
class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
def __init__(self, inner: LanceVectorQuery, table: Optional["AsyncTable"] = None):
def __init__(self, inner: LanceVectorQuery):
"""
Construct an AsyncVectorQuery
@@ -3407,7 +3039,7 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
a vector query. Or you can use
[AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search]
"""
super().__init__(inner, table)
super().__init__(inner)
self._inner = inner
self._reranker = None
self._query_string = None
@@ -3461,13 +3093,10 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
if isinstance(query, str):
return AsyncHybridQuery(
self._inner.nearest_to_text({"query": query, "columns": columns}),
self._table,
self._inner.nearest_to_text({"query": query, "columns": columns})
)
# FullTextQuery object
return AsyncHybridQuery(
self._inner.nearest_to_text({"query": query}), self._table
)
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
async def to_batches(
self,
@@ -3494,8 +3123,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
in the `rerank` method to convert the scores to ranks and then normalize them.
"""
def __init__(self, inner: LanceHybridQuery, table: Optional["AsyncTable"] = None):
super().__init__(inner, table)
def __init__(self, inner: LanceHybridQuery):
super().__init__(inner)
self._inner = inner
self._norm = "score"
self._reranker = RRFReranker()
@@ -3536,8 +3165,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
fts_query = AsyncFTSQuery(self._inner.to_fts_query(), self._table)
vec_query = AsyncVectorQuery(self._inner.to_vector_query(), self._table)
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
# save the row ID choice that was made on the query builder and force it
# to actually fetch the row ids because we need this for reranking
@@ -3637,16 +3266,8 @@ class AsyncTakeQuery(AsyncQueryBase):
Builder for parameterizing and executing take queries.
"""
def __init__(self, inner: LanceTakeQuery, table: Optional["AsyncTable"] = None):
super().__init__(inner, table)
async def _plain_scan_to_pandas(
self,
blob_mode: BlobMode,
flatten: Optional[Union[int, bool]] = None,
**kwargs,
) -> Optional["pd.DataFrame"]:
return None
def __init__(self, inner: LanceTakeQuery):
super().__init__(inner)
class BaseQueryBuilder(object):
@@ -3698,27 +3319,6 @@ class BaseQueryBuilder(object):
self._inner.with_row_id()
return self
def with_row_address(self, with_row_address: bool = True) -> Self:
"""
Include the _rowaddr column in scanner-backed plain query results.
"""
self._inner.with_row_address(with_row_address)
return self
def with_fragments(self, fragments: Any) -> Self:
"""
Restrict scanner-backed plain query results to the given Lance fragments.
"""
self._inner.with_fragments(fragments)
return self
def fragment_ids(self, fragment_ids: List[int]) -> Self:
"""
Restrict scanner-backed plain query results to the given Lance fragment ids.
"""
self._inner.fragment_ids(fragment_ids)
return self
def output_schema(self) -> pa.Schema:
"""
Return the output schema for the query
@@ -3800,8 +3400,6 @@ class BaseQueryBuilder(object):
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
*,
blob_mode: BlobMode = "lazy",
**kwargs,
) -> "pd.DataFrame":
"""
@@ -3835,15 +3433,11 @@ class BaseQueryBuilder(object):
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
blob_mode: str, default "lazy"
Controls how blob columns are returned for plain scan queries.
**kwargs
Forwarded to pyarrow.Table.to_pandas after query execution and
optional flattening.
"""
return LOOP.run(
self._inner.to_pandas(flatten, timeout, blob_mode=blob_mode, **kwargs)
)
return LOOP.run(self._inner.to_pandas(flatten, timeout, **kwargs))
def to_polars(
self,

View File

@@ -383,6 +383,7 @@ class RemoteDBConnection(DBConnection):
namespace_path: Optional[List[str]] = None,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
branch: Optional[str] = None,
) -> Table:
"""Open a Lance Table in the database.
@@ -400,6 +401,9 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
if branch is not None:
raise NotImplementedError("branching is not yet supported on remote tables")
if namespace_path is None:
namespace_path = []
if index_cache_size is not None:

View File

@@ -27,9 +27,6 @@ class LanceDBClientError(RuntimeError):
self.request_id = request_id
self.status_code = status_code
def __reduce__(self) -> tuple[type, tuple]:
return (self.__class__, (str(self), self.request_id, self.status_code))
class HttpError(LanceDBClientError):
"""An error that occurred during an HTTP request.
@@ -104,19 +101,3 @@ class RetryError(LanceDBClientError):
self.max_request_failures = max_request_failures
self.max_connect_failures = max_connect_failures
self.max_read_failures = max_read_failures
def __reduce__(self) -> tuple[type, tuple]:
return (
self.__class__,
(
str(self),
self.request_id,
self.request_failures,
self.connect_failures,
self.read_failures,
self.max_request_failures,
self.max_connect_failures,
self.max_read_failures,
self.status_code,
),
)

View File

@@ -125,9 +125,6 @@ class MRRReranker(Reranker):
This cannot reuse rerank_hybrid because MRR semantics require treating
each vector result as a separate ranking system.
"""
if not vector_results:
raise ValueError("vector_results must not be empty")
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
raise ValueError(
"All elements in vector_results should be of the same type"

View File

@@ -82,9 +82,6 @@ class RRFReranker(Reranker):
results from multiple vector searches as it doesn't support reranking
vector results individually.
"""
if not vector_results:
raise ValueError("vector_results must not be empty")
# Make sure all elements are of the same type
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
raise ValueError(

View File

@@ -30,7 +30,7 @@ from lancedb.scannable import _register_optional_converters, to_scannable
from . import __version__
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP, embedding_executor
from lancedb.background_loop import LOOP
from .dependencies import (
_check_for_hugging_face,
_check_for_lance,
@@ -89,26 +89,6 @@ from .index import lang_mapping
BlobMode = Literal["lazy", "bytes", "descriptions"]
_VALID_BLOB_MODES = ("lazy", "bytes", "descriptions")
def _validate_blob_mode(blob_mode: BlobMode) -> None:
if blob_mode not in _VALID_BLOB_MODES:
modes = ", ".join(repr(mode) for mode in _VALID_BLOB_MODES)
raise ValueError(f"blob_mode must be one of {modes}, got {blob_mode!r}")
def _field_is_blob(field: pa.Field) -> bool:
metadata = field.metadata or {}
return metadata.get(b"lance-encoding:blob") == b"true" or (
metadata.get("lance-encoding:blob") == "true"
)
def _schema_has_blob_field(schema: pa.Schema) -> bool:
return any(_field_is_blob(field) for field in schema)
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
_MODEL_BACKED_TOKENIZER_ERRORS = (
"unknown base tokenizer",
@@ -778,6 +758,15 @@ class Table(ABC):
"""
raise NotImplementedError
@property
def branches(self) -> "Branches":
"""Branch management for the table.
Branches are isolated, writable lines of history forked from another
branch (or version). Writes on a branch do not affect ``main``.
"""
raise NotImplementedError
def __len__(self) -> int:
"""The number of rows in this Table"""
return self.count_rows(None)
@@ -2106,22 +2095,27 @@ class LanceTable(Table):
"Please install with `pip install pylance`."
)
branch = self.current_branch()
version = None if branch is not None else self.version
if self._namespace_client is not None:
table_id = self._namespace_path + [self.name]
return lance.dataset(
version=self.version,
ds = lance.dataset(
version=version,
storage_options=self._conn.storage_options,
namespace_client=self._namespace_client,
table_id=table_id,
**kwargs,
)
return lance.dataset(
self._dataset_path,
version=self.version,
storage_options=self._conn.storage_options,
**kwargs,
)
else:
ds = lance.dataset(
self._dataset_path,
version=version,
storage_options=self._conn.storage_options,
**kwargs,
)
if branch is not None:
ds = ds.checkout_version((branch, self.version))
return ds
@property
def schema(self) -> pa.Schema:
@@ -2187,6 +2181,19 @@ class LanceTable(Table):
"""
return Tags(self._table)
@property
def branches(self) -> "Branches":
"""Branch management for the table.
``create``/``checkout`` return a new table handle scoped to the branch;
writes on it do not affect ``main``.
"""
return Branches(self)
def current_branch(self) -> Optional[str]:
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
return self._table.current_branch()
def checkout(self, version: Union[int, str]):
"""Checkout a version of the table. This is an in-place operation.
@@ -2314,14 +2321,9 @@ class LanceTable(Table):
-------
pd.DataFrame
"""
_validate_blob_mode(blob_mode)
if blob_mode == "descriptions" or not _schema_has_blob_field(self.schema):
return self.to_arrow().to_pandas(**kwargs)
if (
blob_mode == "lazy"
and self._namespace_client is None
and get_uri_scheme(self._dataset_path) == "memory"
if blob_mode == "lazy" and (
self._namespace_client is not None
or get_uri_scheme(self._dataset_path) == "memory"
):
return self.to_arrow().to_pandas(**kwargs)
@@ -3446,9 +3448,13 @@ class LanceTable(Table):
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
# Branch queries run locally: the server-side query protocol can't
# carry a branch yet.
# TODO: push down server-side once it can (with remote table support).
if (
"QueryTable" in self._pushdown_operations
and self._namespace_client is not None
and self.current_branch() is None
):
from lancedb.namespace import _execute_server_side_query
@@ -4342,7 +4348,7 @@ class AsyncTable:
can be executed with methods like [to_arrow][lancedb.query.AsyncQuery.to_arrow],
[to_pandas][lancedb.query.AsyncQuery.to_pandas] and more.
"""
return AsyncQuery(self._inner.query(), self)
return AsyncQuery(self._inner.query())
async def _to_lance(self, **kwargs) -> lance.LanceDataset:
try:
@@ -4353,12 +4359,20 @@ class AsyncTable:
"Please install with `pip install pylance`."
)
return lance.dataset(
# lance.dataset() can't open a branch directly, so open the base table
# and check out the branch ref (a None branch resolves to main).
branch = self.current_branch()
table_version = await self.version()
version = None if branch is not None else table_version
ds = lance.dataset(
await self.uri(),
version=await self.version(),
version=version,
storage_options=await self.latest_storage_options(),
**kwargs,
)
if branch is not None:
ds = ds.checkout_version((branch, table_version))
return ds
async def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -4374,13 +4388,7 @@ class AsyncTable:
-------
pd.DataFrame
"""
_validate_blob_mode(blob_mode)
if blob_mode == "descriptions" or not _schema_has_blob_field(
await self.schema()
):
return (await self.to_arrow()).to_pandas(**kwargs)
if blob_mode == "lazy" and get_uri_scheme(await self.uri()) == "memory":
if blob_mode == "lazy":
return (await self.to_arrow()).to_pandas(**kwargs)
return (await self._to_lance()).to_pandas(blob_mode=blob_mode, **kwargs)
@@ -4908,13 +4916,10 @@ class AsyncTable:
if embedding is not None:
loop = asyncio.get_running_loop()
# This function is likely to block, since it either calls an expensive
# function or makes an HTTP request to an embeddings REST API. Run it
# on a dedicated executor so it can't starve the default executor that
# other blocking I/O shares. See
# https://github.com/lancedb/lancedb/issues/3310.
# function or makes an HTTP request to an embeddings REST API.
return (
await loop.run_in_executor(
embedding_executor(),
None,
embedding.function.compute_query_embeddings_with_retry,
query,
)
@@ -5427,7 +5432,7 @@ class AsyncTable:
pa.RecordBatch
A record batch containing the rows at the given offsets.
"""
return AsyncTakeQuery(self._inner.take_offsets(offsets), self)
return AsyncTakeQuery(self._inner.take_offsets(offsets))
def take_row_ids(self, row_ids: list[int]) -> AsyncTakeQuery:
"""
@@ -5456,7 +5461,7 @@ class AsyncTable:
AsyncTakeQuery
A query object that can be executed to get the rows.
"""
return AsyncTakeQuery(self._inner.take_row_ids(row_ids), self)
return AsyncTakeQuery(self._inner.take_row_ids(row_ids))
@property
def tags(self) -> AsyncTags:
@@ -5476,6 +5481,19 @@ class AsyncTable:
"""
return AsyncTags(self._inner)
@property
def branches(self) -> AsyncBranches:
"""Branch management for the table.
Branches are isolated, writable lines of history forked from another
branch (or version). Writes on a branch do not affect ``main``.
"""
return AsyncBranches(self._inner)
def current_branch(self) -> Optional[str]:
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
return self._inner.current_branch()
async def optimize(
self,
*,
@@ -5811,6 +5829,65 @@ class Tags:
LOOP.run(self._table.tags.update(tag, version))
class Branches:
"""
Table branch manager.
"""
def __init__(self, parent: "LanceTable"):
self._parent = parent
self._table = parent._table
def list(self) -> Dict[str, Any]:
"""List all branches, mapping name to branch metadata."""
return LOOP.run(self._table.branches.list())
def create(
self,
name: str,
from_ref: Optional[str] = None,
from_version: Optional[int] = None,
) -> "LanceTable":
"""Create a branch and return a handle scoped to it.
Parameters
----------
name: str
Name of the new branch.
from_ref: str, optional
Source branch to fork from. Defaults to ``main``.
from_version: int, optional
A specific version on ``from_ref`` to fork from. Defaults to latest.
"""
async_table = LOOP.run(
self._table.branches.create(name, from_ref, from_version)
)
return self._wrap(async_table)
def checkout(self, name: str) -> "LanceTable":
"""Check out an existing branch and return a handle scoped to it."""
async_table = LOOP.run(self._table.branches.checkout(name))
return self._wrap(async_table)
def delete(self, name: str) -> None:
"""Delete a branch."""
LOOP.run(self._table.branches.delete(name))
def _wrap(self, async_table: "AsyncTable") -> "LanceTable":
# Reuse the parent's connection + namespace context; from_inner would drop
# it and break identity/query routing for namespace-backed tables.
parent = self._parent
return LanceTable(
parent._conn,
async_table.name,
namespace_path=parent._namespace_path,
namespace_client=parent._namespace_client,
pushdown_operations=parent._pushdown_operations,
location=parent._location,
_async=async_table,
)
class AsyncTags:
"""
Async table tag manager.
@@ -5878,3 +5955,47 @@ class AsyncTags:
The new table version to tag.
"""
await self._table.tags.update(tag, version)
class AsyncBranches:
"""Async table branch manager."""
def __init__(self, table):
self._table = table
async def list(self) -> Dict[str, Any]:
"""List all branches, mapping name to branch metadata."""
return await self._table.branches.list()
async def create(
self,
name: str,
from_ref: Optional[str] = None,
from_version: Optional[int] = None,
) -> "AsyncTable":
"""Create a branch and return a handle scoped to it.
Parameters
----------
name: str
Name of the new branch.
from_ref: str, optional
Source branch to fork from. Defaults to ``main``.
from_version: int, optional
A specific version on ``from_ref`` to fork from. Defaults to latest.
"""
# "main" and None are two spellings of the root branch in lance; normalize
# so from_ref="main" behaves identically to the default.
if from_ref == "main":
from_ref = None
inner = await self._table.branches.create(name, from_ref, from_version)
return AsyncTable(inner)
async def checkout(self, name: str) -> "AsyncTable":
"""Check out an existing branch and return a handle scoped to it."""
inner = await self._table.branches.checkout(name)
return AsyncTable(inner)
async def delete(self, name: str) -> None:
"""Delete a branch."""
await self._table.branches.delete(name)

View File

@@ -1,56 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pickle
from lancedb.remote.errors import HttpError, LanceDBClientError, RetryError
def test_pickle_lancedb_client_error():
err = LanceDBClientError("something went wrong", "req-123", 400)
restored = pickle.loads(pickle.dumps(err))
assert str(restored) == "something went wrong"
assert restored.request_id == "req-123"
assert restored.status_code == 400
def test_pickle_lancedb_client_error_no_status_code():
err = LanceDBClientError("fail", "req-456")
restored = pickle.loads(pickle.dumps(err))
assert str(restored) == "fail"
assert restored.request_id == "req-456"
assert restored.status_code is None
def test_pickle_http_error():
err = HttpError("not found", "req-789", 404)
restored = pickle.loads(pickle.dumps(err))
assert isinstance(restored, HttpError)
assert str(restored) == "not found"
assert restored.request_id == "req-789"
assert restored.status_code == 404
def test_pickle_retry_error():
err = RetryError(
"max retries exceeded",
"req-abc",
request_failures=3,
connect_failures=1,
read_failures=2,
max_request_failures=5,
max_connect_failures=3,
max_read_failures=3,
status_code=503,
)
restored = pickle.loads(pickle.dumps(err))
assert isinstance(restored, RetryError)
assert str(restored) == "max retries exceeded"
assert restored.request_id == "req-abc"
assert restored.request_failures == 3
assert restored.connect_failures == 1
assert restored.read_failures == 2
assert restored.max_request_failures == 5
assert restored.max_connect_failures == 3
assert restored.max_read_failures == 3
assert restored.status_code == 503

View File

@@ -76,35 +76,6 @@ class TestNamespaceConnection:
assert len(result) == 0
assert list(result.columns) == ["id", "vector", "text"]
def test_table_to_pandas_blob_lazy_through_namespace(self):
"""Namespace-backed tables should use Lance blob-aware pandas conversion."""
pytest.importorskip("lance")
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
db.create_namespace(["test_ns"])
data = pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob",
pa.large_binary(),
metadata={"lance-encoding:blob": "true"},
),
]
),
)
table = db.create_table("blob_table", data, namespace_path=["test_ns"])
df = table.to_pandas(blob_mode="lazy").sort_values("id")
blob = df["blob"].iloc[0]
assert hasattr(blob, "readall")
assert blob.readall() == b"hello"
def test_open_table_through_namespace(self):
"""Test opening an existing table through namespace."""
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})

View File

@@ -39,35 +39,6 @@ from utils import exception_output
from importlib.util import find_spec
def _blob_query_data():
return pa.table(
{
"id": pa.array([1, 2, 3, 4], pa.int64()),
"tag": pa.array(["drop", "keep", "keep", "keep"], pa.utf8()),
"vector": pa.array(
[[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]],
type=pa.list_(pa.float32(), list_size=2),
),
"blob": pa.array([b"one", b"two", b"three", b"four"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field("tag", pa.utf8()),
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
def _assert_lazy_blob(value, expected: bytes):
assert hasattr(value, "readall")
assert value.readall() == expected
@pytest.fixture(scope="module")
def table(tmpdir_factory) -> lancedb.table.Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
@@ -210,216 +181,6 @@ async def test_query_to_pandas_kwargs(table, table_async):
assert async_df["id"].tolist() == [1, 2]
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
def test_plain_scan_query_to_pandas_blob_modes(tmp_db, blob_mode):
pytest.importorskip("lance")
table = tmp_db.create_table(
f"test_query_to_pandas_blob_{blob_mode}", _blob_query_data()
)
df = (
table.search()
.select(["id", "blob"])
.where("id = 1")
.to_pandas(blob_mode=blob_mode)
)
assert df["id"].tolist() == [1]
if blob_mode == "lazy":
_assert_lazy_blob(df["blob"].iloc[0], b"one")
elif blob_mode == "bytes":
assert df["blob"].tolist() == [b"one"]
else:
first = df["blob"].iloc[0]
assert first != b"one"
assert not hasattr(first, "readall")
def test_plain_scan_query_to_pandas_blob_projection(tmp_db):
pytest.importorskip("lance")
table = tmp_db.create_table(
"test_query_to_pandas_blob_projection", _blob_query_data()
)
df = (
table.search()
.where("id >= 2")
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
.limit(2)
.offset(1)
.to_pandas(blob_mode="bytes")
)
assert df["id_alias"].tolist() == [3, 4]
assert df["payload"].tolist() == [b"three", b"four"]
assert df["double_id"].tolist() == [6, 8]
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
tmp_db, monkeypatch, blob_mode
):
pytest.importorskip("lance")
table = tmp_db.create_table(
"test_query_to_pandas_blob_no_arrow_collect", _blob_query_data()
)
query = table.search().where("id = 1").select(["id", "blob"])
def fail_to_arrow(*args, **kwargs):
raise AssertionError("to_arrow should not be called before native pandas")
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
df = query.to_pandas(blob_mode=blob_mode)
assert df["id"].tolist() == [1]
if blob_mode == "bytes":
assert df["blob"].tolist() == [b"one"]
else:
first = df["blob"].iloc[0]
assert first != b"one"
assert not hasattr(first, "readall")
def test_plain_scan_query_to_pandas_blob_descriptions_flatten_uses_scanner(
tmp_db, monkeypatch
):
pytest.importorskip("lance")
table = tmp_db.create_table(
"test_query_to_pandas_blob_desc_flatten", _blob_query_data()
)
query = table.search().where("id = 1").select(["id", "blob"])
def fail_to_arrow(*args, **kwargs):
raise AssertionError("to_arrow should not be called before scanner pandas")
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
df = query.to_pandas(blob_mode="descriptions", flatten=True)
assert df["id"].tolist() == [1]
assert any(column == "blob" or column.startswith("blob.") for column in df.columns)
def test_plain_scan_query_to_pandas_scanner_state(tmp_db):
pytest.importorskip("lance")
data = _blob_query_data()
table = tmp_db.create_table("test_query_to_pandas_scanner_state", data.slice(0, 2))
table.add(data.slice(2, 2))
fragments = table.to_lance().get_fragments()
assert len(fragments) == 2
query = (
table.search()
.select(["id", "blob"])
.with_row_address()
.fragment_ids([fragments[1].fragment_id])
)
query_obj = query.to_query_object()
assert query_obj.with_row_address is True
assert query_obj.fragment_ids == [fragments[1].fragment_id]
df = query.to_pandas(blob_mode="descriptions")
assert df["id"].tolist() == [3, 4]
assert "_rowaddr" in df.columns
assert {rowaddr >> 32 for rowaddr in df["_rowaddr"]} == {fragments[1].fragment_id}
df_by_fragment = (
table.search()
.select(["id", "blob"])
.with_fragments([fragments[0]])
.to_pandas(blob_mode="descriptions")
)
assert df_by_fragment["id"].tolist() == [1, 2]
@pytest.mark.asyncio
async def test_async_plain_scan_query_to_pandas_blob_projection(tmp_db_async):
pytest.importorskip("lance")
table = await tmp_db_async.create_table(
"test_async_query_to_pandas_blob_projection", _blob_query_data()
)
lazy_df = await (
table.query().where("id = 1").select(["id", "blob"]).to_pandas(blob_mode="lazy")
)
assert lazy_df["id"].tolist() == [1]
_assert_lazy_blob(lazy_df["blob"].iloc[0], b"one")
bytes_df = await (
table.query()
.where("id >= 2")
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
.limit(2)
.offset(1)
.to_pandas(blob_mode="bytes")
)
assert bytes_df["id_alias"].tolist() == [3, 4]
assert bytes_df["payload"].tolist() == [b"three", b"four"]
assert bytes_df["double_id"].tolist() == [6, 8]
desc_df = await (
table.query()
.where("id = 1")
.select(["blob"])
.to_pandas(blob_mode="descriptions")
)
first = desc_df["blob"].iloc[0]
assert first != b"one"
assert not hasattr(first, "readall")
@pytest.mark.asyncio
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
tmp_db_async, monkeypatch, blob_mode
):
pytest.importorskip("lance")
table = await tmp_db_async.create_table(
"test_async_query_to_pandas_blob_no_arrow_collect", _blob_query_data()
)
query = table.query().where("id = 1").select(["id", "blob"])
async def fail_to_arrow(*args, **kwargs):
raise AssertionError("to_arrow should not be called before native pandas")
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
df = await query.to_pandas(blob_mode=blob_mode)
assert df["id"].tolist() == [1]
if blob_mode == "bytes":
assert df["blob"].tolist() == [b"one"]
else:
first = df["blob"].iloc[0]
assert first != b"one"
assert not hasattr(first, "readall")
def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db):
pytest.importorskip("lance")
table = tmp_db.create_table("test_vector_query_blob_mode", _blob_query_data())
with pytest.raises(RuntimeError, match="Lance native pandas conversion"):
table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
blob_mode="lazy"
)
def test_vector_query_to_pandas_blob_descriptions_requires_plain_scan(tmp_db):
pytest.importorskip("lance")
table = tmp_db.create_table(
"test_vector_query_blob_descriptions", _blob_query_data()
)
with pytest.raises(RuntimeError, match="plain scan query"):
table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
blob_mode="descriptions"
)
def test_order_by_plain_query(mem_db):
table = mem_db.create_table(
"test_order_by",

View File

@@ -344,12 +344,6 @@ def test_mrr_reranker(tmp_path):
assert len(result_deduped) == len(result)
def test_mrr_reranker_empty_input():
reranker = MRRReranker()
with pytest.raises(ValueError, match="must not be empty"):
reranker.rerank_multivector([])
def test_rrf_reranker_distance():
data = pa.table(
{

View File

@@ -4,7 +4,6 @@
import os
import sys
import threading
import warnings
from datetime import date, datetime, timedelta
from time import sleep
@@ -27,28 +26,6 @@ from lancedb.table import LanceTable
from pydantic import BaseModel
def _blob_test_data():
return pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
def _assert_lazy_blob(value, expected: bytes):
assert hasattr(value, "readall")
assert value.readall() == expected
def test_basic(mem_db: DBConnection):
data = [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
@@ -80,30 +57,27 @@ def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
pd.testing.assert_frame_equal(table.to_pandas(), expected)
def test_table_to_pandas_invalid_blob_mode_non_blob_table(tmp_db: DBConnection):
data = pa.table({"id": [1, 2], "text": ["one", "two"]})
table = tmp_db.create_table("test_to_pandas_invalid_blob_mode", data=data)
with pytest.raises(ValueError, match="blob_mode must be one of"):
table.to_pandas(blob_mode="invalid")
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
def test_table_to_pandas_blob_modes(tmp_db: DBConnection, blob_mode):
def test_table_to_pandas_blob_bytes(tmp_db: DBConnection):
pytest.importorskip("lance")
table = tmp_db.create_table(f"test_to_pandas_blob_{blob_mode}", _blob_test_data())
data = pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
table = tmp_db.create_table("test_to_pandas_blob_bytes", data=data)
df = table.to_pandas(blob_mode=blob_mode)
df = table.to_pandas(blob_mode="bytes")
if blob_mode == "lazy":
_assert_lazy_blob(df["blob"].iloc[0], b"hello")
_assert_lazy_blob(df["blob"].iloc[1], b"world")
elif blob_mode == "bytes":
assert df["blob"].tolist() == [b"hello", b"world"]
else:
first = df["blob"].iloc[0]
assert first != b"hello"
assert not hasattr(first, "readall")
assert df["blob"].tolist() == [b"hello", b"world"]
def test_table_to_pandas_kwargs(tmp_db: DBConnection):
@@ -119,8 +93,22 @@ def test_table_to_pandas_kwargs(tmp_db: DBConnection):
@pytest.mark.asyncio
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
pytest.importorskip("lance")
data = pa.table(
{
"id": pa.array([1, 2], pa.int64()),
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
},
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field(
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
),
]
),
)
table = await tmp_db_async.create_table(
"test_async_to_pandas_blob_bytes", data=_blob_test_data()
"test_async_to_pandas_blob_bytes", data=data
)
df = await table.to_pandas(blob_mode="bytes")
@@ -128,19 +116,6 @@ async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
assert df["blob"].tolist() == [b"hello", b"world"]
@pytest.mark.asyncio
async def test_async_table_to_pandas_invalid_blob_mode_non_blob_table(
tmp_db_async: AsyncConnection,
):
table = await tmp_db_async.create_table(
"test_async_to_pandas_invalid_blob_mode",
data=pa.table({"id": [1, 2], "text": ["one", "two"]}),
)
with pytest.raises(ValueError, match="blob_mode must be one of"):
await table.to_pandas(blob_mode="invalid")
@pytest.mark.asyncio
async def test_async_table_to_pandas_kwargs(tmp_db_async: AsyncConnection):
pd = pytest.importorskip("pandas")
@@ -928,6 +903,160 @@ async def test_async_tags(mem_db_async: AsyncConnection):
)
def test_branches(tmp_path):
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0))
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
assert table.count_rows() == 2
# fork an isolated, writable branch from main
branch = table.branches.create("exp")
assert branch.count_rows() == 2
branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
# writes on the branch do not touch main
assert branch.count_rows() == 3
assert table.count_rows() == 2
# the branch is listed, with main (None) as its parent
branches = table.branches.list()
assert "exp" in branches
assert branches["exp"]["parent_branch"] is None
# from_ref="main" is equivalent to the default
table.branches.create("exp2", from_ref="main")
assert table.branches.list()["exp2"]["parent_branch"] is None
# checkout returns a handle scoped to the branch's latest
checked_out = table.branches.checkout("exp")
assert checked_out.count_rows() == 3
# delete removes it
table.branches.delete("exp")
table.branches.delete("exp2")
assert "exp" not in table.branches.list()
def test_branch_handle_tracks_concurrent_writes(tmp_path):
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0))
table = db.create_table("t", [{"id": 1}])
# two independent handles on the same branch
writer = table.branches.create("exp")
reader = db.open_table("t", branch="exp")
assert reader.count_rows() == 1
# a concurrent write on the branch is visible to the other handle
writer.add([{"id": 2}])
assert reader.count_rows() == 2
# main is unaffected
assert table.count_rows() == 1
def test_branch_name_validation(tmp_path):
db = lancedb.connect(tmp_path)
table = db.create_table("t", [{"id": 1}])
with pytest.raises(ValueError, match="non-empty"):
table.branches.create("")
with pytest.raises(ValueError, match="non-empty"):
table.branches.checkout("")
with pytest.raises(ValueError, match="non-empty"):
table.branches.delete("")
def test_branches_preserve_namespace(tmp_path):
pytest.importorskip(
"lance"
) # namespace_path routes through lance's DirectoryNamespace
db = lancedb.connect(tmp_path)
table = db.create_table("t", [{"id": 1}], namespace_path=["ns1"])
assert table.namespace == ["ns1"]
branch = table.branches.create("exp")
assert branch.namespace == ["ns1"]
assert branch.id == table.id
# opening the branch directly also preserves namespace identity
opened = db.open_table("t", namespace_path=["ns1"], branch="exp")
assert opened.namespace == ["ns1"]
def test_open_table_with_branch(tmp_path):
db = lancedb.connect(tmp_path)
table = db.create_table("t", [{"i": 1}])
table.branches.create("exp").add([{"i": 2}])
# open_table(branch=...) returns a handle scoped to the branch
assert db.open_table("t", branch="exp").count_rows() == 2
# opening without branch still tracks main
assert db.open_table("t").count_rows() == 1
@pytest.mark.asyncio
async def test_async_namespace_open_table_with_branch(tmp_path):
pytest.importorskip("lance") # "dir" impl is lance.namespace.DirectoryNamespace
db = lancedb.connect_namespace_async("dir", {"root": str(tmp_path)})
await db.create_namespace(["ns1"])
table = await db.create_table("t", [{"id": 1}], namespace_path=["ns1"])
branch = await table.branches.create("exp")
await branch.add([{"id": 2}])
# open_table(branch=...) on the async namespace connection must work
opened = await db.open_table("t", namespace_path=["ns1"], branch="exp")
assert await opened.count_rows() == 2
def test_branch_to_lance_targets_branch(tmp_path):
pytest.importorskip("lance")
db = lancedb.connect(tmp_path)
table = db.create_table("t", [{"i": 1}])
branch = table.branches.create("exp")
branch.add([{"i": 2}]) # branch: 2 rows, main: 1 row
assert branch.to_lance().count_rows() == 2
assert table.to_lance().count_rows() == 1
@pytest.mark.asyncio
async def test_async_branches(tmp_path):
db = await lancedb.connect_async(tmp_path)
table = await db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
assert await table.count_rows() == 2
branch = await table.branches.create("exp")
assert await branch.count_rows() == 2
await branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
assert await branch.count_rows() == 3
assert await table.count_rows() == 2
branches = await table.branches.list()
assert "exp" in branches
assert branches["exp"]["parent_branch"] is None
await table.branches.create("exp2", from_ref="main")
assert (await table.branches.list())["exp2"]["parent_branch"] is None
checked_out = await table.branches.checkout("exp")
assert await checked_out.count_rows() == 3
await table.branches.delete("exp")
await table.branches.delete("exp2")
assert "exp" not in await table.branches.list()
@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_method(mock_create_index, mem_db: DBConnection):
table = mem_db.create_table(
@@ -1289,45 +1418,6 @@ def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
def test_add_nullable_struct_with_none(mem_db: DBConnection):
"""Regression test for issue #2654: a nullable struct column whose
first batch contains only None values must not crash in
_align_field_types with AttributeError: 'pyarrow.lib.DataType'
object has no attribute 'fields'.
PyArrow infers an all-None struct column as `null` (not `struct`),
so the type-alignment path needs to handle the case where the
source field type is null and use the target type directly.
"""
# Use the v2.1 file format so that nullable structs are supported.
table = mem_db.create_table(
"test_nullable_struct",
schema=pa.schema(
[
pa.field("id", pa.string()),
pa.field(
"data",
pa.struct([pa.field("x", pa.float32())]),
nullable=True,
),
]
),
storage_options=dict(new_table_data_storage_version="2.1"),
)
# Adding a row with a non-null struct should work.
table.add([{"id": "1", "data": {"x": 1.0}}])
# Adding a row with None for the nullable struct field should also
# work — this is what used to crash.
table.add([{"id": "2", "data": None}])
result = table.to_arrow()
assert result.num_rows == 2
assert result.column("id").to_pylist() == ["1", "2"]
assert result.column("data").to_pylist() == [{"x": 1.0}, None]
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
class Schema(LanceModel):
text: str
@@ -2838,38 +2928,3 @@ def test_sanitize_data_metadata_not_stripped():
assert result_schema.metadata is not None
assert result_schema.metadata[b"existing_key"] == b"existing_value"
assert result_schema.metadata[b"new_key"] == b"new_value"
@pytest.mark.asyncio
async def test_async_search_runs_embedding_on_dedicated_executor(
mem_db_async: AsyncConnection,
):
# Regression test for #3310: AsyncTable.search() must run the (potentially
# blocking) query-embedding call on the dedicated embedding executor, not
# asyncio's default executor -- which is shared with other blocking I/O and
# can be starved by a slow embedding call under concurrent load.
func = MockTextEmbeddingFunction.create()
class Schema(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = await mem_db_async.create_table("embed_executor", schema=Schema)
await table.add([{"text": "hello world"}])
captured_threads: List[str] = []
original = MockTextEmbeddingFunction.generate_embeddings
def record_thread(self, texts):
captured_threads.append(threading.current_thread().name)
return original(self, texts)
# Patch only around the search so we capture the query-embedding call, not
# the add-time source-embedding call.
with patch.object(MockTextEmbeddingFunction, "generate_embeddings", record_thread):
await (await table.search("a query string")).limit(1).to_list()
assert captured_threads, "search did not invoke the embedding function"
assert all(name.startswith("lancedb-embedding") for name in captured_threads), (
f"embedding ran off the dedicated executor: {captured_threads}"
)

View File

@@ -17,7 +17,7 @@ use arrow::{
};
use lancedb::table::{
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
OptimizeAction, OptimizeOptions, Table as LanceDbTable,
OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
};
use pyo3::{
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -864,6 +864,15 @@ impl Table {
Ok(Tags::new(self.inner_ref()?.clone()))
}
pub fn current_branch(&self) -> PyResult<Option<String>> {
Ok(self.inner_ref()?.current_branch())
}
#[getter]
pub fn branches(&self) -> PyResult<Branches> {
Ok(Branches::new(self.inner_ref()?.clone()))
}
#[pyo3(signature = (offsets))]
pub fn take_offsets(self_: PyRef<'_, Self>, offsets: Vec<u64>) -> PyResult<TakeQuery> {
Ok(TakeQuery::new(
@@ -1265,3 +1274,66 @@ impl Tags {
})
}
}
#[pyclass]
pub struct Branches {
inner: LanceDbTable,
}
impl Branches {
pub fn new(table: LanceDbTable) -> Self {
Self { inner: table }
}
}
#[pymethods]
impl Branches {
pub fn list(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let res = inner.list_branches().await.infer_error()?;
Python::attach(|py| {
let py_dict = PyDict::new(py);
for (name, contents) in res {
let value = PyDict::new(py);
value.set_item("parent_branch", contents.parent_branch)?;
value.set_item("parent_version", contents.parent_version)?;
value.set_item("manifest_size", contents.manifest_size)?;
py_dict.set_item(name, value)?;
}
Ok(py_dict.unbind())
})
})
}
#[pyo3(signature = (name, from_ref=None, from_version=None))]
pub fn create(
self_: PyRef<'_, Self>,
name: String,
from_ref: Option<String>,
from_version: Option<u64>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let from = Ref::Version(from_ref, from_version);
let table = inner.create_branch(&name, from).await.infer_error()?;
Ok(Table::new(table))
})
}
pub fn checkout(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let table = inner.checkout_branch(&name).await.infer_error()?;
Ok(Table::new(table))
})
}
pub fn delete(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
inner.delete_branch(&name).await.infer_error()?;
Ok(())
})
}
}

4228
python/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.30.1-beta.2"
version = "0.30.1-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -119,6 +119,7 @@ pub struct OpenTableBuilder {
parent: Arc<dyn Database>,
request: OpenTableRequest,
embedding_registry: Arc<dyn EmbeddingRegistry>,
branch: Option<String>,
}
impl OpenTableBuilder {
@@ -139,6 +140,7 @@ impl OpenTableBuilder {
managed_versioning: None,
},
embedding_registry,
branch: None,
}
}
@@ -259,14 +261,22 @@ impl OpenTableBuilder {
self
}
/// Open the table scoped to the given branch instead of the default branch.
///
/// Reads and writes on the returned table operate in the branch's context.
pub fn branch(mut self, branch: impl Into<String>) -> Self {
self.branch = Some(branch.into());
self
}
/// Open the table
pub async fn execute(self) -> Result<Table> {
let table = self.parent.open_table(self.request).await?;
Ok(Table::new_with_embedding_registry(
table,
self.parent,
self.embedding_registry,
))
let table = Table::new_with_embedding_registry(table, self.parent, self.embedding_registry);
match self.branch {
Some(branch) => table.checkout_branch(&branch).await,
None => Ok(table),
}
}
}

View File

@@ -740,6 +740,64 @@ mod tests {
assert!(table_names.contains(&"test_table".to_string()));
}
#[tokio::test]
async fn test_namespace_branch_query_under_pushdown_stays_local() {
// With QueryTable pushdown enabled, a query on the main branch routes to
// the namespace server, but a branch handle must run locally: the
// server-side request carries no branch and would return main's rows.
let tmp_dir = tempdir().unwrap();
let root_path = tmp_dir.path().to_str().unwrap().to_string();
let mut properties = HashMap::new();
properties.insert("root".to_string(), root_path);
let conn = connect_namespace("dir", properties)
.pushdown_operation(NamespaceClientPushdownOperation::QueryTable)
.execute()
.await
.expect("Failed to connect to namespace");
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
})
.await
.expect("Failed to create namespace");
// main has 5 rows
let table = conn
.create_table("ref_test", create_test_data())
.namespace(vec!["test_ns".into()])
.execute()
.await
.expect("Failed to create table");
let main_version = table.version().await.unwrap();
// fork a branch off main, then add 5 more rows so it differs from main
let branch = table
.create_branch("exp", main_version)
.await
.expect("Failed to create branch");
branch
.add(create_test_data())
.execute()
.await
.expect("Failed to append to branch");
// the branch query must run locally and see the branch's 10 rows --
// not get routed to the server (which carries no branch) and see main's 5
let results = branch
.query()
.execute()
.await
.expect("Failed to query branch")
.try_collect::<Vec<_>>()
.await
.expect("Failed to collect results");
let count: usize = results.iter().map(|b| b.num_rows()).sum();
assert_eq!(count, 10);
}
#[tokio::test]
async fn test_namespace_describe_table() {
// Setup: Create a temporary directory for the namespace

View File

@@ -203,11 +203,11 @@ impl Shuffler {
// Finish writing files
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
let write_summary = writer.finish().await?;
let num_written = writer.finish().await?;
log::debug!(
"Shuffle job {}: wrote {} rows to file {}",
self.id,
write_summary.num_rows,
num_written,
file_idx
);
}

View File

@@ -23,7 +23,6 @@ use crate::table::DropColumnsResult;
use crate::table::MergeResult;
use crate::table::Tags;
use crate::table::UpdateResult;
use crate::table::merge::MergeFilter;
use crate::table::query::create_multi_vector_plan;
use crate::table::{AlterColumnsResult, FieldMetadataUpdate, UpdateFieldMetadataResult};
use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics};
@@ -1384,6 +1383,38 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.map_err(unwrap_shared_error)
}
async fn create_branch(
&self,
_name: &str,
_from: lance::dataset::refs::Ref,
) -> Result<Arc<dyn BaseTable>> {
Err(Error::NotSupported {
message: "branching is not yet supported on remote tables".into(),
})
}
async fn checkout_branch(&self, _name: &str) -> Result<Arc<dyn BaseTable>> {
Err(Error::NotSupported {
message: "branching is not yet supported on remote tables".into(),
})
}
async fn list_branches(&self) -> Result<HashMap<String, lance::dataset::refs::BranchContents>> {
Err(Error::NotSupported {
message: "branching is not yet supported on remote tables".into(),
})
}
async fn delete_branch(&self, _name: &str) -> Result<()> {
Err(Error::NotSupported {
message: "branching is not yet supported on remote tables".into(),
})
}
fn current_branch(&self) -> Option<String> {
None
}
async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
let mut request = self.post_read(&format!("/v1/table/{}/count_rows/", self.identifier));
@@ -1827,57 +1858,16 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
})
}
async fn set_lsm_write_spec(&self, spec: crate::table::LsmWriteSpec) -> Result<()> {
use crate::table::LsmWriteSpec;
self.check_mutable().await?;
// Map the spec onto the server's request DTO. `sharding` is internally
// tagged on `mode` to mirror sophon's `Sharding` enum; `maintained_indexes`
// and `writer_config_defaults` are sent verbatim (an empty list means "no
// maintained indexes", not "default to all").
let sharding = match &spec {
LsmWriteSpec::Bucket {
column,
num_buckets,
..
} => serde_json::json!({
"mode": "bucket",
"column": column,
"num_buckets": num_buckets,
}),
LsmWriteSpec::Identity { column, .. } => serde_json::json!({
"mode": "identity",
"column": column,
}),
LsmWriteSpec::Unsharded { .. } => serde_json::json!({ "mode": "unsharded" }),
};
let body = serde_json::json!({
"sharding": sharding,
"maintained_indexes": spec.maintained_indexes(),
"writer_config_defaults": spec.writer_config_defaults(),
});
let request = self
.client
.post(&format!(
"/v1/table/{}/set_lsm_write_spec/",
self.identifier
))
.json(&body);
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
async fn set_lsm_write_spec(&self, _spec: crate::table::LsmWriteSpec) -> Result<()> {
Err(Error::NotSupported {
message: "set_lsm_write_spec is not supported on LanceDB cloud.".into(),
})
}
async fn unset_lsm_write_spec(&self) -> Result<()> {
self.check_mutable().await?;
let request = self.client.post(&format!(
"/v1/table/{}/unset_lsm_write_spec/",
self.identifier
));
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
Err(Error::NotSupported {
message: "unset_lsm_write_spec is not supported on LanceDB cloud.".into(),
})
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -2308,34 +2298,13 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
}
let on = value.on[0].clone();
let when_matched_update_all_filt = match value.when_matched_update_all_filt {
Some(MergeFilter::Sql(sql)) => Some(sql),
Some(MergeFilter::Expr(_)) => {
return Err(Error::NotSupported {
message: "DataFusion expressions are not supported on remote tables".into(),
});
}
None => None,
};
let when_not_matched_by_source_delete_filt =
match value.when_not_matched_by_source_delete_filt {
Some(MergeFilter::Sql(sql)) => Some(sql),
Some(MergeFilter::Expr(_)) => {
return Err(Error::NotSupported {
message: "DataFusion expressions are not supported on remote tables".into(),
});
}
None => None,
};
Ok(Self {
on,
when_matched_update_all: value.when_matched_update_all,
when_matched_update_all_filt,
when_matched_update_all_filt: value.when_matched_update_all_filt,
when_not_matched_insert_all: value.when_not_matched_insert_all,
when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
when_not_matched_by_source_delete_filt,
when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
// Only serialize use_index when it's false for backwards compatibility
use_index: value.use_index,
})
@@ -4469,91 +4438,6 @@ mod tests {
assert!(matches!(e, Error::IndexNotFound { .. }));
}
#[tokio::test]
async fn test_set_lsm_write_spec_unsharded() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(
request.url().path(),
"/v1/table/my_table/set_lsm_write_spec/"
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
assert_eq!(body["sharding"], serde_json::json!({ "mode": "unsharded" }));
assert_eq!(body["maintained_indexes"], serde_json::json!(["id_idx"]));
assert_eq!(
body["writer_config_defaults"],
serde_json::json!({ "max_memtable_rows": "1000" })
);
http::Response::builder()
.status(200)
.body(r#"{"maintained_indexes":["id_idx"]}"#)
.unwrap()
});
let spec = crate::table::LsmWriteSpec::unsharded()
.with_maintained_indexes(["id_idx"])
.with_writer_config_defaults([("max_memtable_rows", "1000")]);
table.set_lsm_write_spec(spec).await.unwrap();
}
#[tokio::test]
async fn test_set_lsm_write_spec_bucket() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(
request.url().path(),
"/v1/table/my_table/set_lsm_write_spec/"
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
assert_eq!(
body["sharding"],
serde_json::json!({ "mode": "bucket", "column": "id", "num_buckets": 16 })
);
assert_eq!(body["maintained_indexes"], serde_json::json!([]));
http::Response::builder().status(200).body("{}").unwrap()
});
table
.set_lsm_write_spec(crate::table::LsmWriteSpec::bucket("id", 16))
.await
.unwrap();
}
#[tokio::test]
async fn test_set_lsm_write_spec_identity() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(
request.url().path(),
"/v1/table/my_table/set_lsm_write_spec/"
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
assert_eq!(
body["sharding"],
serde_json::json!({ "mode": "identity", "column": "tenant" })
);
http::Response::builder().status(200).body("{}").unwrap()
});
table
.set_lsm_write_spec(crate::table::LsmWriteSpec::identity("tenant"))
.await
.unwrap();
}
#[tokio::test]
async fn test_unset_lsm_write_spec() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(
request.url().path(),
"/v1/table/my_table/unset_lsm_write_spec/"
);
http::Response::builder().status(200).body("{}").unwrap()
});
table.unset_lsm_write_spec().await.unwrap();
}
#[tokio::test]
async fn test_wait_for_index() {
let table = _make_table_with_indices(0);

View File

@@ -86,7 +86,7 @@ pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
pub use chrono::Duration;
pub use delete::DeleteResult;
use futures::future::join_all;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
pub use lance::dataset::refs::{BranchContents, Ref, TagContents, Tags as LanceTags};
pub use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::statistics::DatasetStatisticsExt;
pub use lance_index::optimize::OptimizeOptions;
@@ -625,6 +625,20 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn restore(&self) -> Result<()>;
/// List the versions of the table.
async fn list_versions(&self) -> Result<Vec<Version>>;
/// Create a new branch from `from` and return a handle scoped to it.
async fn create_branch(
&self,
name: &str,
from: lance::dataset::refs::Ref,
) -> Result<Arc<dyn BaseTable>>;
/// Check out an existing branch and return a handle scoped to it.
async fn checkout_branch(&self, name: &str) -> Result<Arc<dyn BaseTable>>;
/// List the branches of the table.
async fn list_branches(&self) -> Result<HashMap<String, BranchContents>>;
/// Delete a branch.
async fn delete_branch(&self, name: &str) -> Result<()>;
/// The branch this handle is scoped to, or `None` for `main`.
fn current_branch(&self) -> Option<String>;
/// Get the table definition.
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI (storage location)
@@ -1625,6 +1639,45 @@ impl Table {
self.inner.tags().await
}
/// Create a new branch from `from` (a version, tag, or branch)
pub async fn create_branch(
&self,
name: &str,
from: impl Into<lance::dataset::refs::Ref>,
) -> Result<Self> {
let inner = self.inner.create_branch(name, from.into()).await?;
Ok(Self {
inner,
database: self.database.clone(),
embedding_registry: self.embedding_registry.clone(),
})
}
/// Check out an existing branch and return a handle scoped to it.
pub async fn checkout_branch(&self, name: &str) -> Result<Self> {
let inner = self.inner.checkout_branch(name).await?;
Ok(Self {
inner,
database: self.database.clone(),
embedding_registry: self.embedding_registry.clone(),
})
}
/// List the branches of the table.
pub async fn list_branches(&self) -> Result<HashMap<String, BranchContents>> {
self.inner.list_branches().await
}
/// Delete a branch.
pub async fn delete_branch(&self, name: &str) -> Result<()> {
self.inner.delete_branch(name).await
}
/// The branch this handle is scoped to, or `None` for `main`.
pub fn current_branch(&self) -> Option<String> {
self.inner.current_branch()
}
/// Retrieve statistics on the table
pub async fn stats(&self) -> Result<TableStatistics> {
self.inner.stats().await
@@ -1861,6 +1914,30 @@ impl NativeTable {
self
}
/// Build a sibling `NativeTable` with the same identity but a different
/// (independent) dataset wrapper — used to hand out branch-scoped handles.
fn with_dataset(&self, dataset: dataset::DatasetConsistencyWrapper) -> Self {
Self {
name: self.name.clone(),
namespace: self.namespace.clone(),
id: self.id.clone(),
uri: self.uri.clone(),
dataset,
read_consistency_interval: self.read_consistency_interval,
namespace_client: self.namespace_client.clone(),
pushdown_operations: self.pushdown_operations.clone(),
}
}
fn validate_branch_name(name: &str, field: &str) -> Result<()> {
if name.is_empty() {
return Err(Error::InvalidInput {
message: format!("{field} must be a non-empty string"),
});
}
Ok(())
}
/// Opens an existing Table using a namespace client.
///
/// This method uses `DatasetBuilder::from_namespace` to open the table, which
@@ -2652,6 +2729,49 @@ impl BaseTable for NativeTable {
self.dataset.reload().await
}
async fn create_branch(
&self,
name: &str,
from: lance::dataset::refs::Ref,
) -> Result<Arc<dyn BaseTable>> {
Self::validate_branch_name(name, "branch name")?;
if let lance::dataset::refs::Ref::Version(Some(from_branch), _) = &from {
Self::validate_branch_name(from_branch, "from_ref")?;
}
let mut ds = (*self.dataset.get().await?).clone();
let branch_ds = ds.create_branch(name, from, None).await?;
let dataset = dataset::DatasetConsistencyWrapper::new_latest(
branch_ds,
self.read_consistency_interval,
);
Ok(Arc::new(self.with_dataset(dataset)))
}
async fn checkout_branch(&self, name: &str) -> Result<Arc<dyn BaseTable>> {
Self::validate_branch_name(name, "branch name")?;
let branch_ds = self.dataset.get().await?.checkout_branch(name).await?;
let dataset = dataset::DatasetConsistencyWrapper::new_latest(
branch_ds,
self.read_consistency_interval,
);
Ok(Arc::new(self.with_dataset(dataset)))
}
async fn list_branches(&self) -> Result<HashMap<String, BranchContents>> {
Ok(self.dataset.get().await?.list_branches().await?)
}
async fn delete_branch(&self, name: &str) -> Result<()> {
Self::validate_branch_name(name, "branch name")?;
let mut ds = (*self.dataset.get().await?).clone();
ds.delete_branch(name).await?;
Ok(())
}
fn current_branch(&self) -> Option<String> {
self.dataset.current_branch()
}
async fn list_versions(&self) -> Result<Vec<Version>> {
Ok(self.dataset.get().await?.versions().await?)
}
@@ -3378,6 +3498,171 @@ mod tests {
assert_eq!(table.version().await.unwrap(), 4);
}
#[tokio::test]
async fn test_branches() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = ConnectBuilder::new(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
// main: one row at v1
let table = conn
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 1);
assert_eq!(table.current_branch(), None);
let main_version = table.version().await.unwrap();
// branch off main's current version; it starts with main's data
let branch = table.create_branch("exp", main_version).await.unwrap();
assert_eq!(branch.current_branch().as_deref(), Some("exp"));
assert_eq!(branch.count_rows(None).await.unwrap(), 1);
// writes on the branch are isolated from main
branch.add(some_sample_data()).execute().await.unwrap();
assert_eq!(branch.count_rows(None).await.unwrap(), 2);
assert_eq!(
table.count_rows(None).await.unwrap(),
1,
"main must be untouched by branch writes"
);
// the branch shows up in the listing
let branches = table.list_branches().await.unwrap();
assert!(branches.contains_key("exp"));
// checking out the branch from the main handle sees the branch's latest data
let checked_out = table.checkout_branch("exp").await.unwrap();
assert_eq!(checked_out.current_branch().as_deref(), Some("exp"));
assert_eq!(checked_out.count_rows(None).await.unwrap(), 2);
// open_table(...).branch(...) opens directly onto the branch
let opened = conn
.open_table("my_table")
.branch("exp")
.execute()
.await
.unwrap();
assert_eq!(opened.current_branch().as_deref(), Some("exp"));
assert_eq!(opened.count_rows(None).await.unwrap(), 2);
// delete removes it from the listing
table.delete_branch("exp").await.unwrap();
let branches = table.list_branches().await.unwrap();
assert!(!branches.contains_key("exp"));
}
#[tokio::test]
async fn test_branch_name_validation() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = ConnectBuilder::new(uri).execute().await.unwrap();
let table = conn
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
// every entry point rejects an empty name instead of passing it down
assert!(matches!(
table.create_branch("", 1u64).await,
Err(Error::InvalidInput { .. })
));
assert!(matches!(
table.checkout_branch("").await,
Err(Error::InvalidInput { .. })
));
assert!(matches!(
table.delete_branch("").await,
Err(Error::InvalidInput { .. })
));
// an empty source branch is rejected too
assert!(matches!(
table
.create_branch(
"ok",
lance::dataset::refs::Ref::Version(Some(String::new()), None)
)
.await,
Err(Error::InvalidInput { .. })
));
}
#[tokio::test]
async fn test_branch_handle_tracks_concurrent_writes() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// interval = 0 so every read checks storage for new commits
let conn = ConnectBuilder::new(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let table = conn
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
let v1 = table.version().await.unwrap();
// two independent handles on the same branch
let writer = table.create_branch("exp", v1).await.unwrap();
let reader = conn
.open_table("my_table")
.branch("exp")
.execute()
.await
.unwrap();
assert_eq!(reader.count_rows(None).await.unwrap(), 1);
// a concurrent write on the branch is visible to the other handle, which
// tracks the branch's HEAD (not main's)
writer.add(some_sample_data()).execute().await.unwrap();
assert_eq!(reader.count_rows(None).await.unwrap(), 2);
// main is untouched
assert_eq!(table.count_rows(None).await.unwrap(), 1);
}
#[tokio::test]
async fn test_branch_handle_without_consistency_interval_is_pinned() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// default interval (None): handles do not auto-refresh
let conn = ConnectBuilder::new(uri).execute().await.unwrap();
let table = conn
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
let v1 = table.version().await.unwrap();
let writer = table.create_branch("exp", v1).await.unwrap();
let reader = conn
.open_table("my_table")
.branch("exp")
.execute()
.await
.unwrap();
assert_eq!(reader.count_rows(None).await.unwrap(), 1);
// without a consistency interval the reader stays on the version it
// opened, exactly like a main-branch handle...
writer.add(some_sample_data()).execute().await.unwrap();
assert_eq!(reader.count_rows(None).await.unwrap(), 1);
// ...until it explicitly refreshes
reader.checkout_latest().await.unwrap();
assert_eq!(reader.count_rows(None).await.unwrap(), 2);
}
#[tokio::test]
async fn test_create_index() {
use arrow_array::RecordBatch;

View File

@@ -144,8 +144,19 @@ impl DatasetConsistencyWrapper {
}
/// Checkout a branch and track its HEAD for new versions.
pub async fn as_branch(&self, _branch: impl Into<String>) -> Result<()> {
todo!("Branch support not yet implemented")
pub async fn as_branch(&self, branch: impl Into<String>) -> Result<()> {
let branch = branch.into();
let dataset = { self.state.lock()?.dataset.clone() };
let new_dataset = dataset.checkout_branch(&branch).await?;
let mut state = self.state.lock()?;
state.dataset = Arc::new(new_dataset);
state.pinned_version = None;
drop(state);
if let ConsistencyMode::Eventual(bg_cache) = &self.consistency {
bg_cache.invalidate();
}
Ok(())
}
/// Check that the dataset is in a mutable mode (Latest).
@@ -161,6 +172,17 @@ impl DatasetConsistencyWrapper {
}
}
/// The branch this wrapper is currently tracking, or `None` for `main`.
pub fn current_branch(&self) -> Option<String> {
self.state
.lock()
.unwrap_or_else(|e| e.into_inner())
.dataset
.manifest()
.branch
.clone()
}
/// Returns the version, if in time travel mode, or None otherwise.
pub fn time_travel_version(&self) -> Option<u64> {
self.state
@@ -737,4 +759,31 @@ mod tests {
let result = wrapper.reload().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_as_branch_is_writable_and_tracked() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
// v1 on main, then shallow-clone a branch off it
let mut ds = create_test_dataset(uri).await;
let v1 = ds.version().version;
ds.create_branch("exp", v1, None).await.unwrap();
// wrapper starts on main: latest, writable, no branch
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
assert_eq!(wrapper.current_branch(), None);
// switch to the branch
wrapper.as_branch("exp").await.unwrap();
assert_eq!(wrapper.current_branch().as_deref(), Some("exp"));
// a branch is writable (unlike a pinned/time-travel checkout)
wrapper.ensure_mutable().unwrap();
assert_eq!(wrapper.time_travel_version(), None);
// get() returns the branch dataset
let on_branch = wrapper.get().await.unwrap();
assert_eq!(on_branch.manifest().branch.as_deref(), Some("exp"));
}
}

View File

@@ -53,12 +53,6 @@ pub struct MergeResult {
pub num_rows: u64,
}
#[derive(Debug, Clone)]
pub enum MergeFilter {
Sql(String),
Expr(datafusion_expr::Expr),
}
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
@@ -67,10 +61,10 @@ pub struct MergeInsertBuilder {
table: Arc<dyn BaseTable>,
pub(crate) on: Vec<String>,
pub(crate) when_matched_update_all: bool,
pub(crate) when_matched_update_all_filt: Option<MergeFilter>,
pub(crate) when_matched_update_all_filt: Option<String>,
pub(crate) when_not_matched_insert_all: bool,
pub(crate) when_not_matched_by_source_delete: bool,
pub(crate) when_not_matched_by_source_delete_filt: Option<MergeFilter>,
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
pub(crate) timeout: Option<Duration>,
pub(crate) use_index: bool,
pub(crate) use_lsm_write: Option<bool>,
@@ -116,14 +110,7 @@ impl MergeInsertBuilder {
/// For example, "target.last_update < source.last_update"
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
self.when_matched_update_all = true;
self.when_matched_update_all_filt = condition.map(MergeFilter::Sql);
self
}
/// Similar to [`Self::when_matched_update_all`] but accepts a DataFusion logical expression directly.
pub fn when_matched_update_all_expr(&mut self, condition: datafusion_expr::Expr) -> &mut Self {
self.when_matched_update_all = true;
self.when_matched_update_all_filt = Some(MergeFilter::Expr(condition));
self.when_matched_update_all_filt = condition;
self
}
@@ -145,17 +132,7 @@ impl MergeInsertBuilder {
/// limit what rows are deleted.
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = filter.map(MergeFilter::Sql);
self
}
/// Similar to [`Self::when_not_matched_by_source_delete`] but accepts a DataFusion logical expression directly.
pub fn when_not_matched_by_source_delete_expr(
&mut self,
filter: datafusion_expr::Expr,
) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = Some(MergeFilter::Expr(filter));
self.when_not_matched_by_source_delete_filt = filter;
self
}
@@ -257,12 +234,7 @@ pub(crate) async fn execute_merge_insert(
) {
(false, _) => builder.when_matched(WhenMatched::DoNothing),
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
(true, Some(MergeFilter::Sql(filt))) => {
builder.when_matched(WhenMatched::update_if(&dataset, &filt)?)
}
(true, Some(MergeFilter::Expr(expr))) => {
builder.when_matched(WhenMatched::update_if_expr(expr))
}
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
};
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
@@ -270,12 +242,10 @@ pub(crate) async fn execute_merge_insert(
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = match params.when_not_matched_by_source_delete_filt {
Some(MergeFilter::Sql(filter)) => {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
}
Some(MergeFilter::Expr(expr)) => WhenNotMatchedBySource::DeleteIf(expr),
None => WhenNotMatchedBySource::Delete,
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
@@ -416,45 +386,6 @@ mod tests {
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 25);
}
#[tokio::test]
async fn test_merge_insert_expr() {
use datafusion_expr::{col, lit};
let conn = connect("memory://").execute().await.unwrap();
// Create a dataset with i=0..10
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("my_table_expr", batches)
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Conditional update that only replaces the age=0 data
let new_batches = merge_insert_test_batches(5, 3);
let mut merge_insert_builder = table.merge_insert(&["i"]);
// use expression: target.age = 0
let expr = col("target.age").eq(lit(0));
merge_insert_builder.when_matched_update_all_expr(expr);
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
5
);
// Delete with expression
// Create new batches with i=10..20 (so target rows i=0..9 are not matched by source)
let new_batches = merge_insert_test_batches(10, 0); // won't insert or update since we don't enable matched/unmatched actions
let mut merge_insert_builder = table.merge_insert(&["i"]);
// delete if target.age = 3
let delete_expr = col("target.age").eq(lit(3));
merge_insert_builder.when_not_matched_by_source_delete_expr(delete_expr);
let result = merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(result.num_deleted_rows, 5);
assert_eq!(table.count_rows(None).await.unwrap(), 5);
}
}
#[cfg(test)]

View File

@@ -41,11 +41,14 @@ pub async fn execute_query(
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
// If QueryTable pushdown is enabled and namespace client is configured, use server-side query execution
// QueryTable pushdown runs the query server-side, but only on the main
// branch: the namespace request carries no branch yet, so a branch handle
// must fall through to local execution.
if table
.pushdown_operations
.contains(&NamespaceClientPushdownOperation::QueryTable)
&& let Some(ref namespace_client) = table.namespace_client
&& table.dataset.current_branch().is_none()
{
return execute_namespace_query(table, namespace_client.clone(), query, options).await;
}