mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-07 06:10:38 +00:00
Compare commits
22 Commits
feat/table
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f6a12ce6b | ||
|
|
59fbfd4158 | ||
|
|
f37e698e2f | ||
|
|
09b1bbc12a | ||
|
|
c484b24e51 | ||
|
|
3868965413 | ||
|
|
c13ebc6796 | ||
|
|
4b287fd9c4 | ||
|
|
64194ea8ad | ||
|
|
e6c5de1a58 | ||
|
|
39a9f3e1e9 | ||
|
|
952055d428 | ||
|
|
927ba2c948 | ||
|
|
415d199c15 | ||
|
|
a16676e05f | ||
|
|
4e44262499 | ||
|
|
632375faf1 | ||
|
|
9969191d0d | ||
|
|
1e7326cd8c | ||
|
|
9483b534af | ||
|
|
ac3411e81e | ||
|
|
6f18eb4cce |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.1-beta.0"
|
||||
current_version = "0.30.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
@@ -21,3 +21,14 @@ updates:
|
||||
update-types:
|
||||
- minor
|
||||
- patch
|
||||
|
||||
- package-ecosystem: pip
|
||||
directory: /python
|
||||
schedule:
|
||||
interval: weekly
|
||||
# Only update uv.lock, never widen version requirements in pyproject.toml.
|
||||
versioning-strategy: lockfile-only
|
||||
groups:
|
||||
python-deps:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
539
Cargo.lock
generated
539
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=8.0.0-beta.6", default-features = false, "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=8.0.0-beta.6", default-features = false, "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=8.0.0-beta.6", default-features = false, "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=8.0.0-beta.6", "tag" = "v8.0.0-beta.6", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
|
||||
26
REVIEW.md
Normal file
26
REVIEW.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Code review guidelines
|
||||
|
||||
Repo-specific guidance for automated PR reviews.
|
||||
|
||||
## Cross-SDK parity
|
||||
|
||||
LanceDB exposes the same core (`rust/lancedb`) through Python, TypeScript (`nodejs`),
|
||||
and Java bindings. Behavioral drift between SDKs is a recurring problem, so watch for
|
||||
parity gaps when reviewing — but only flag real ones:
|
||||
|
||||
* If the change adds or modifies user-facing API or behavior in the shared core
|
||||
(`rust/lancedb`), check whether each binding that should expose it (`python`,
|
||||
`nodejs`) does. A core change with no corresponding binding update is worth a note.
|
||||
* If the change adds or modifies a public API in one SDK but not the other, open the
|
||||
sibling SDK's corresponding module and state whether an equivalent exists. If not,
|
||||
note it as a possible parity gap and suggest a follow-up issue.
|
||||
* For bug fixes, first read the sibling SDK's analogous code path to check whether the
|
||||
same bug exists there. Only raise parity if it actually does. Do not ask to "port" a
|
||||
fix for a bug that only ever existed in one binding.
|
||||
* Stay silent on internal-only refactors, tests, docs, and changes with no cross-SDK
|
||||
surface.
|
||||
* Parity expectations apply to the Python and TypeScript (`nodejs`) SDKs. Java currently
|
||||
implements only the remote table, not the local/embedded backend, so it is expected to
|
||||
be partial — do not flag Java for missing local-only functionality.
|
||||
* Keep parity feedback to a short, clearly-labeled note (e.g. "Possible SDK parity
|
||||
gap: …"). It is advisory, not a merge blocker.
|
||||
@@ -147,6 +147,14 @@ allow = [
|
||||
"CDLA-Permissive-2.0",
|
||||
]
|
||||
confidence-threshold = 0.8
|
||||
# Per-crate license exceptions: allow a license for a specific crate only,
|
||||
# rather than globally via the `allow` list above.
|
||||
exceptions = [
|
||||
# CDDL-1.0 (copyleft) is pulled in only as a dev/profiling dependency via
|
||||
# `inferno` -> `pprof` -> `lance-testing`; it is a test dependency that we
|
||||
# do not distribute, so scope the allowance to `inferno` alone.
|
||||
{ allow = ["CDDL-1.0"], crate = "inferno" },
|
||||
]
|
||||
# Crates whose license cannot be determined from Cargo metadata but whose
|
||||
# license we've manually confirmed from upstream. Keep this list minimal.
|
||||
[[licenses.clarify]]
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -30,17 +30,6 @@ The type of the index
|
||||
|
||||
***
|
||||
|
||||
### loss?
|
||||
|
||||
```ts
|
||||
optional loss: number;
|
||||
```
|
||||
|
||||
The KMeans loss value of the index,
|
||||
it is only present for vector indices.
|
||||
|
||||
***
|
||||
|
||||
### numIndexedRows
|
||||
|
||||
```ts
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>7.2.0-beta.1</lance-core.version>
|
||||
<lance-core.version>8.0.0-beta.6</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.30.1-beta.2"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -721,7 +721,7 @@ describe("When creating an index", () => {
|
||||
columns: ["vec"],
|
||||
});
|
||||
const stats = await tbl.indexStats("vec_idx");
|
||||
expect(stats?.loss).toBeDefined();
|
||||
expect(stats).toBeDefined();
|
||||
|
||||
// Search without specifying the column
|
||||
let rst = await tbl
|
||||
@@ -1150,7 +1150,6 @@ describe("When creating an index", () => {
|
||||
expect(stats?.distanceType).toBeUndefined();
|
||||
expect(stats?.indexType).toEqual("BTREE");
|
||||
expect(stats?.numIndices).toEqual(1);
|
||||
expect(stats?.loss).toBeUndefined();
|
||||
});
|
||||
|
||||
test("when getting stats on non-existent index", async () => {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.30.1-beta.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -838,9 +838,6 @@ pub struct IndexStatistics {
|
||||
pub distance_type: Option<String>,
|
||||
/// The number of parts this index is split into.
|
||||
pub num_indices: Option<u32>,
|
||||
/// The KMeans loss value of the index,
|
||||
/// it is only present for vector indices.
|
||||
pub loss: Option<f64>,
|
||||
}
|
||||
impl From<lancedb::index::IndexStatistics> for IndexStatistics {
|
||||
fn from(value: lancedb::index::IndexStatistics) -> Self {
|
||||
@@ -850,7 +847,6 @@ impl From<lancedb::index::IndexStatistics> for IndexStatistics {
|
||||
index_type: value.index_type.to_string(),
|
||||
distance_type: value.distance_type.map(|d| d.to_string()),
|
||||
num_indices: value.num_indices,
|
||||
loss: value.loss,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.33.1-beta.0"
|
||||
current_version = "0.33.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.1-beta.0"
|
||||
version = "0.33.1-beta.2"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
import threading
|
||||
import warnings
|
||||
@@ -37,6 +38,24 @@ class BackgroundEventLoop:
|
||||
|
||||
LOOP = BackgroundEventLoop()
|
||||
|
||||
|
||||
def _new_embedding_executor() -> concurrent.futures.ThreadPoolExecutor:
|
||||
return concurrent.futures.ThreadPoolExecutor(thread_name_prefix="lancedb-embedding")
|
||||
|
||||
|
||||
# Embedding functions can block for a long time -- a heavy local model or an
|
||||
# HTTP request to a remote embeddings API. Running them on asyncio's default
|
||||
# executor lets them starve the unrelated blocking I/O that shares that pool,
|
||||
# so they get a dedicated one. See
|
||||
# https://github.com/lancedb/lancedb/issues/3310.
|
||||
_EMBEDDING_EXECUTOR = _new_embedding_executor()
|
||||
|
||||
|
||||
def embedding_executor() -> concurrent.futures.ThreadPoolExecutor:
|
||||
"""Return the executor dedicated to running blocking embedding calls."""
|
||||
return _EMBEDDING_EXECUTOR
|
||||
|
||||
|
||||
_FORK_WARNED = False
|
||||
|
||||
|
||||
@@ -47,6 +66,12 @@ def _reset_after_fork():
|
||||
# the new state. The Rust-side tokio runtime is reset analogously by a
|
||||
# pthread_atfork hook installed in the _lancedb extension.
|
||||
LOOP._start()
|
||||
# The embedding executor's worker threads are dead in the child as well.
|
||||
# Replace it with a fresh pool (threads are spawned lazily, so this is
|
||||
# cheap); we don't shut down the old one, since joining its dead workers
|
||||
# could hang.
|
||||
global _EMBEDDING_EXECUTOR
|
||||
_EMBEDDING_EXECUTOR = _new_embedding_executor()
|
||||
global _FORK_WARNED
|
||||
if not _FORK_WARNED:
|
||||
_FORK_WARNED = True
|
||||
|
||||
@@ -41,6 +41,14 @@ from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
|
||||
BlobMode = Literal["lazy", "bytes", "descriptions"]
|
||||
|
||||
_BLOB_MODE_TO_HANDLING = {
|
||||
"lazy": "blobs_descriptions",
|
||||
"bytes": "all_binary",
|
||||
"descriptions": "blobs_descriptions",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
|
||||
@@ -55,7 +63,7 @@ if TYPE_CHECKING:
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from .common import VEC
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
from .table import AsyncTable, Table
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
@@ -65,6 +73,179 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound="LanceModel")
|
||||
|
||||
|
||||
def _validate_blob_mode(blob_mode: BlobMode) -> None:
|
||||
if blob_mode not in _BLOB_MODE_TO_HANDLING:
|
||||
modes = ", ".join(repr(mode) for mode in _BLOB_MODE_TO_HANDLING)
|
||||
raise ValueError(f"blob_mode must be one of {modes}, got {blob_mode!r}")
|
||||
|
||||
|
||||
def _field_is_blob(field: pa.Field) -> bool:
|
||||
metadata = field.metadata or {}
|
||||
return metadata.get(b"lance-encoding:blob") == b"true" or (
|
||||
metadata.get("lance-encoding:blob") == "true"
|
||||
)
|
||||
|
||||
|
||||
def _schema_has_blob_field(schema: pa.Schema) -> bool:
|
||||
return any(_field_is_blob(field) for field in schema)
|
||||
|
||||
|
||||
def _blob_mode_requires_native_pandas(blob_mode: BlobMode, schema: pa.Schema) -> bool:
|
||||
return blob_mode in _BLOB_MODE_TO_HANDLING and _schema_has_blob_field(schema)
|
||||
|
||||
|
||||
def _unsupported_blob_pandas_error(reason: str) -> RuntimeError:
|
||||
return RuntimeError(
|
||||
"blob columns require Lance native scanner conversion for query "
|
||||
f"to_pandas(), but {reason}. Use a plain scan query or remove blob "
|
||||
"columns from the projection."
|
||||
)
|
||||
|
||||
|
||||
def _query_is_plain_scan(query: Query) -> bool:
|
||||
return (
|
||||
query.vector is None
|
||||
and query.full_text_query is None
|
||||
and not query.postfilter
|
||||
and not query.order_by
|
||||
)
|
||||
|
||||
|
||||
def _filter_to_sql(filter: Optional[Union[str, Expr]]) -> Optional[str]:
|
||||
if filter is None:
|
||||
return None
|
||||
if isinstance(filter, Expr):
|
||||
return filter.to_sql()
|
||||
return filter
|
||||
|
||||
|
||||
def _projection_to_scanner_kwargs(
|
||||
columns: Optional[
|
||||
Union[
|
||||
List[str], List[Tuple[str, Union[str, Expr]]], Dict[str, Union[str, Expr]]
|
||||
]
|
||||
],
|
||||
) -> Dict[str, Any]:
|
||||
if columns is None:
|
||||
return {}
|
||||
if isinstance(columns, list):
|
||||
if all(isinstance(column, str) for column in columns):
|
||||
return {"columns": columns}
|
||||
if all(isinstance(column, tuple) and len(column) == 2 for column in columns):
|
||||
return {
|
||||
"columns": {
|
||||
name: expr.to_sql() if isinstance(expr, Expr) else expr
|
||||
for name, expr in columns
|
||||
}
|
||||
}
|
||||
# Let Lance raise the detailed projection validation error.
|
||||
return {"columns": columns}
|
||||
|
||||
projection = {}
|
||||
for name, expr in columns.items():
|
||||
if isinstance(expr, Expr):
|
||||
expr = expr.to_sql()
|
||||
projection[name] = expr
|
||||
return {"columns": projection}
|
||||
|
||||
|
||||
def _scanner_kwargs_for_query(
|
||||
query: Query, blob_mode: BlobMode, dataset: Optional[Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
fragments = _scanner_fragments_for_query(query, dataset)
|
||||
kwargs = {
|
||||
**_projection_to_scanner_kwargs(query.columns),
|
||||
"filter": _filter_to_sql(query.filter),
|
||||
"limit": query.limit,
|
||||
"offset": query.offset,
|
||||
"with_row_id": query.with_row_id,
|
||||
"with_row_address": query.with_row_address,
|
||||
"fast_search": query.fast_search,
|
||||
"blob_handling": _BLOB_MODE_TO_HANDLING[blob_mode],
|
||||
"fragments": fragments,
|
||||
}
|
||||
return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
|
||||
def _scanner_fragments_for_query(query: Query, dataset: Optional[Any]) -> Optional[Any]:
|
||||
if query.fragments is not None and query.fragment_ids is not None:
|
||||
raise ValueError("fragments and fragment_ids cannot both be set")
|
||||
if query.fragments is not None:
|
||||
return query.fragments
|
||||
if query.fragment_ids is None:
|
||||
return None
|
||||
if dataset is None:
|
||||
raise ValueError("fragment_ids require a Lance dataset")
|
||||
|
||||
requested = set(query.fragment_ids)
|
||||
fragments = [
|
||||
fragment
|
||||
for fragment in dataset.get_fragments()
|
||||
if fragment.fragment_id in requested
|
||||
]
|
||||
found = {fragment.fragment_id for fragment in fragments}
|
||||
missing = requested - found
|
||||
if missing:
|
||||
missing_ids = ", ".join(str(fragment_id) for fragment_id in sorted(missing))
|
||||
raise ValueError(f"fragment_ids not found in dataset: {missing_ids}")
|
||||
return fragments
|
||||
|
||||
|
||||
def _ensure_lazy_blob_frame(
|
||||
df: "pd.DataFrame", schema: pa.Schema, blob_mode: BlobMode
|
||||
) -> "pd.DataFrame":
|
||||
if blob_mode != "lazy" or not _schema_has_blob_field(schema) or len(df) == 0:
|
||||
return df
|
||||
|
||||
for field in schema:
|
||||
if not _field_is_blob(field) or field.name not in df.columns:
|
||||
continue
|
||||
value = df[field.name].iloc[0]
|
||||
if value is not None and not hasattr(value, "readall"):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"the Lance scanner did not return lazy blob files"
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def _scanner_to_table(scanner: Any) -> pa.Table:
|
||||
if hasattr(scanner, "to_pyarrow"):
|
||||
reader = scanner.to_pyarrow()
|
||||
return reader.read_all()
|
||||
if hasattr(scanner, "to_table"):
|
||||
return scanner.to_table()
|
||||
reader = scanner.to_reader()
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataFrame":
|
||||
schema = getattr(scanner, "projected_schema", None)
|
||||
if schema is None:
|
||||
schema = getattr(scanner, "schema", None)
|
||||
if schema is None:
|
||||
schema = getattr(scanner, "dataset_schema", None)
|
||||
if callable(schema):
|
||||
schema = schema()
|
||||
if hasattr(scanner, "to_pandas"):
|
||||
try:
|
||||
df = scanner.to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
except TypeError as err:
|
||||
message = str(err)
|
||||
if "blob_mode" not in message and "unexpected keyword" not in message:
|
||||
raise
|
||||
df = scanner.to_pandas(**kwargs)
|
||||
if schema is not None:
|
||||
return _ensure_lazy_blob_frame(df, schema, blob_mode)
|
||||
return df
|
||||
|
||||
tbl = _scanner_to_table(scanner)
|
||||
if blob_mode == "lazy" and _schema_has_blob_field(tbl.schema):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"the Lance scanner does not expose to_pandas"
|
||||
)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
|
||||
# Pydantic validation function for vector queries
|
||||
def ensure_vector_query(
|
||||
val: Any,
|
||||
@@ -499,6 +680,13 @@ class Query(pydantic.BaseModel):
|
||||
# if true, include the row id in the results
|
||||
with_row_id: Optional[bool] = None
|
||||
|
||||
# if true, include the row address in the results
|
||||
with_row_address: Optional[bool] = None
|
||||
|
||||
# Lance fragments or fragment ids to scan on scanner-backed plain queries
|
||||
fragments: Optional[Any] = None
|
||||
fragment_ids: Optional[List[int]] = None
|
||||
|
||||
# offset to start fetching results from
|
||||
offset: Optional[int] = None
|
||||
|
||||
@@ -691,6 +879,9 @@ class LanceQueryBuilder(ABC):
|
||||
self._where = None
|
||||
self._postfilter = None
|
||||
self._with_row_id = None
|
||||
self._with_row_address = None
|
||||
self._fragments = None
|
||||
self._fragment_ids = None
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
@@ -718,6 +909,7 @@ class LanceQueryBuilder(ABC):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
timeout: Optional[timedelta] = None,
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
@@ -737,11 +929,41 @@ class LanceQueryBuilder(ABC):
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
Vector, FTS, hybrid, and other non-native query shapes keep the
|
||||
existing Arrow conversion path and only support blob descriptions.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
_validate_blob_mode(blob_mode)
|
||||
output_schema = getattr(self, "output_schema", None)
|
||||
if output_schema is not None:
|
||||
schema = output_schema()
|
||||
if _blob_mode_requires_native_pandas(blob_mode, schema):
|
||||
native_error = None
|
||||
if (flatten is None or blob_mode == "descriptions") and timeout is None:
|
||||
try:
|
||||
df = self._plain_scan_to_pandas(
|
||||
blob_mode, flatten=flatten, **kwargs
|
||||
)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
native_error = err
|
||||
reason = (
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
if native_error is None
|
||||
else str(native_error)
|
||||
)
|
||||
raise _unsupported_blob_pandas_error(reason) from native_error
|
||||
|
||||
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
|
||||
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
@@ -947,6 +1169,32 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = with_row_id
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
|
||||
"""Set whether to return row addresses.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
with_row_address: bool, default True
|
||||
If True, return the _rowaddr column in the results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._with_row_address = with_row_address
|
||||
return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
|
||||
"""Set the Lance fragments to scan for plain scanner-backed queries."""
|
||||
self._fragments = fragments
|
||||
return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
|
||||
"""Set the Lance fragment ids to scan for plain scanner-backed queries."""
|
||||
self._fragment_ids = fragment_ids
|
||||
return self
|
||||
|
||||
def explain_plan(self, verbose: Optional[bool] = False) -> str:
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
@@ -1086,6 +1334,25 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
query = self.to_query_object()
|
||||
if not _query_is_plain_scan(query):
|
||||
return None
|
||||
|
||||
dataset = self._table.to_lance()
|
||||
scanner = dataset.scanner(
|
||||
**_scanner_kwargs_for_query(query, blob_mode, dataset)
|
||||
)
|
||||
if flatten is not None:
|
||||
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def to_query_object(self) -> Query:
|
||||
"""Return a serializable representation of the query
|
||||
@@ -1357,6 +1624,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
@@ -1559,6 +1829,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
limit=self._limit,
|
||||
postfilter=self._postfilter,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
full_text_query=FullTextSearchQuery(
|
||||
query=self._query, columns=self._fts_columns
|
||||
),
|
||||
@@ -1629,6 +1902,9 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
with_row_address=self._with_row_address,
|
||||
fragments=self._fragments,
|
||||
fragment_ids=self._fragment_ids,
|
||||
offset=self._offset,
|
||||
order_by=self._order_by,
|
||||
)
|
||||
@@ -2207,7 +2483,11 @@ class AsyncQueryBase(object):
|
||||
Base class for all async queries (take, scan, vector, fts, hybrid)
|
||||
"""
|
||||
|
||||
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery]):
|
||||
def __init__(
|
||||
self,
|
||||
inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery],
|
||||
table: Optional["AsyncTable"] = None,
|
||||
):
|
||||
"""
|
||||
Construct an AsyncQueryBase
|
||||
|
||||
@@ -2215,6 +2495,10 @@ class AsyncQueryBase(object):
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
self._inner = inner
|
||||
self._table = table
|
||||
self._with_row_address = None
|
||||
self._fragments = None
|
||||
self._fragment_ids = None
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
"""
|
||||
@@ -2223,7 +2507,11 @@ class AsyncQueryBase(object):
|
||||
This is currently experimental but can be useful as the query object is pure
|
||||
python and more easily serializable.
|
||||
"""
|
||||
return Query.from_inner(self._inner.to_query_request())
|
||||
query = Query.from_inner(self._inner.to_query_request())
|
||||
query.with_row_address = self._with_row_address
|
||||
query.fragments = self._fragments
|
||||
query.fragment_ids = self._fragment_ids
|
||||
return query
|
||||
|
||||
def select(self, columns: Union[List[str], dict[str, str]]) -> Self:
|
||||
"""
|
||||
@@ -2280,6 +2568,27 @@ class AsyncQueryBase(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
|
||||
"""
|
||||
Include the _rowaddr column in scanner-backed plain query results.
|
||||
"""
|
||||
self._with_row_address = with_row_address
|
||||
return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
|
||||
"""
|
||||
Restrict scanner-backed plain query results to the given Lance fragments.
|
||||
"""
|
||||
self._fragments = fragments
|
||||
return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
|
||||
"""
|
||||
Restrict scanner-backed plain query results to the given Lance fragment ids.
|
||||
"""
|
||||
self._fragment_ids = fragment_ids
|
||||
return self
|
||||
|
||||
async def to_batches(
|
||||
self,
|
||||
*,
|
||||
@@ -2357,6 +2666,8 @@ class AsyncQueryBase(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
@@ -2390,13 +2701,63 @@ class AsyncQueryBase(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
Vector, FTS, hybrid, and other non-native query shapes keep the
|
||||
existing Arrow conversion path and only support blob descriptions.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return (
|
||||
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
).to_pandas(**kwargs)
|
||||
_validate_blob_mode(blob_mode)
|
||||
if hasattr(self._inner, "output_schema"):
|
||||
schema = await self.output_schema()
|
||||
if _blob_mode_requires_native_pandas(blob_mode, schema):
|
||||
native_error = None
|
||||
if (flatten is None or blob_mode == "descriptions") and timeout is None:
|
||||
try:
|
||||
df = await self._plain_scan_to_pandas(
|
||||
blob_mode, flatten=flatten, **kwargs
|
||||
)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
native_error = err
|
||||
reason = (
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
if native_error is None
|
||||
else str(native_error)
|
||||
)
|
||||
raise _unsupported_blob_pandas_error(reason) from native_error
|
||||
|
||||
tbl = flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
if self._table is None:
|
||||
return None
|
||||
|
||||
query = self.to_query_object()
|
||||
if not _query_is_plain_scan(query):
|
||||
return None
|
||||
|
||||
dataset = await self._table._to_lance()
|
||||
scanner = dataset.scanner(
|
||||
**_scanner_kwargs_for_query(query, blob_mode, dataset)
|
||||
)
|
||||
if flatten is not None:
|
||||
tbl = flatten_columns(_scanner_to_table(scanner), flatten)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
async def to_polars(
|
||||
self,
|
||||
@@ -2503,14 +2864,18 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
Base class for "standard" async queries (all but take currently)
|
||||
"""
|
||||
|
||||
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
|
||||
def __init__(
|
||||
self,
|
||||
inner: Union[LanceQuery, LanceVectorQuery],
|
||||
table: Optional["AsyncTable"] = None,
|
||||
):
|
||||
"""
|
||||
Construct an AsyncStandardQuery
|
||||
|
||||
This method is not intended to be called directly. Instead, use the
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
|
||||
def where(self, predicate: Union[str, Expr]) -> Self:
|
||||
"""
|
||||
@@ -2616,14 +2981,14 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
|
||||
|
||||
class AsyncQuery(AsyncStandardQuery):
|
||||
def __init__(self, inner: LanceQuery):
|
||||
def __init__(self, inner: LanceQuery, table: Optional["AsyncTable"] = None):
|
||||
"""
|
||||
Construct an AsyncQuery
|
||||
|
||||
This method is not intended to be called directly. Instead, use the
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
|
||||
@classmethod
|
||||
@@ -2707,10 +3072,11 @@ class AsyncQuery(AsyncStandardQuery):
|
||||
new_self = self._inner.nearest_to(query_vectors[0])
|
||||
for v in query_vectors[1:]:
|
||||
new_self.add_query_vector(v)
|
||||
return AsyncVectorQuery(new_self)
|
||||
return AsyncVectorQuery(new_self, self._table)
|
||||
else:
|
||||
return AsyncVectorQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
|
||||
self._table,
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
@@ -2743,17 +3109,18 @@ class AsyncQuery(AsyncStandardQuery):
|
||||
|
||||
if isinstance(query, str):
|
||||
return AsyncFTSQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns}),
|
||||
self._table,
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}), self._table)
|
||||
|
||||
|
||||
class AsyncFTSQuery(AsyncStandardQuery):
|
||||
"""A query for full text search for LanceDB."""
|
||||
|
||||
def __init__(self, inner: LanceFTSQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceFTSQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
|
||||
@@ -2835,10 +3202,11 @@ class AsyncFTSQuery(AsyncStandardQuery):
|
||||
new_self = self._inner.nearest_to(query_vectors[0])
|
||||
for v in query_vectors[1:]:
|
||||
new_self.add_query_vector(v)
|
||||
return AsyncHybridQuery(new_self)
|
||||
return AsyncHybridQuery(new_self, self._table)
|
||||
else:
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
|
||||
self._table,
|
||||
)
|
||||
|
||||
async def to_batches(
|
||||
@@ -3029,7 +3397,7 @@ class AsyncVectorQueryBase:
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
def __init__(self, inner: LanceVectorQuery):
|
||||
def __init__(self, inner: LanceVectorQuery, table: Optional["AsyncTable"] = None):
|
||||
"""
|
||||
Construct an AsyncVectorQuery
|
||||
|
||||
@@ -3039,7 +3407,7 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
a vector query. Or you can use
|
||||
[AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search]
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
self._query_string = None
|
||||
@@ -3093,10 +3461,13 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
|
||||
if isinstance(query, str):
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns}),
|
||||
self._table,
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to_text({"query": query}), self._table
|
||||
)
|
||||
|
||||
async def to_batches(
|
||||
self,
|
||||
@@ -3123,8 +3494,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
in the `rerank` method to convert the scores to ranks and then normalize them.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: LanceHybridQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceHybridQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._norm = "score"
|
||||
self._reranker = RRFReranker()
|
||||
@@ -3165,8 +3536,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
|
||||
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
|
||||
fts_query = AsyncFTSQuery(self._inner.to_fts_query(), self._table)
|
||||
vec_query = AsyncVectorQuery(self._inner.to_vector_query(), self._table)
|
||||
|
||||
# save the row ID choice that was made on the query builder and force it
|
||||
# to actually fetch the row ids because we need this for reranking
|
||||
@@ -3266,8 +3637,16 @@ class AsyncTakeQuery(AsyncQueryBase):
|
||||
Builder for parameterizing and executing take queries.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: LanceTakeQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceTakeQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
return None
|
||||
|
||||
|
||||
class BaseQueryBuilder(object):
|
||||
@@ -3319,6 +3698,27 @@ class BaseQueryBuilder(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def with_row_address(self, with_row_address: bool = True) -> Self:
|
||||
"""
|
||||
Include the _rowaddr column in scanner-backed plain query results.
|
||||
"""
|
||||
self._inner.with_row_address(with_row_address)
|
||||
return self
|
||||
|
||||
def with_fragments(self, fragments: Any) -> Self:
|
||||
"""
|
||||
Restrict scanner-backed plain query results to the given Lance fragments.
|
||||
"""
|
||||
self._inner.with_fragments(fragments)
|
||||
return self
|
||||
|
||||
def fragment_ids(self, fragment_ids: List[int]) -> Self:
|
||||
"""
|
||||
Restrict scanner-backed plain query results to the given Lance fragment ids.
|
||||
"""
|
||||
self._inner.fragment_ids(fragment_ids)
|
||||
return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
@@ -3400,6 +3800,8 @@ class BaseQueryBuilder(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
@@ -3433,11 +3835,15 @@ class BaseQueryBuilder(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return LOOP.run(self._inner.to_pandas(flatten, timeout, **kwargs))
|
||||
return LOOP.run(
|
||||
self._inner.to_pandas(flatten, timeout, blob_mode=blob_mode, **kwargs)
|
||||
)
|
||||
|
||||
def to_polars(
|
||||
self,
|
||||
|
||||
@@ -27,6 +27,9 @@ class LanceDBClientError(RuntimeError):
|
||||
self.request_id = request_id
|
||||
self.status_code = status_code
|
||||
|
||||
def __reduce__(self) -> tuple[type, tuple]:
|
||||
return (self.__class__, (str(self), self.request_id, self.status_code))
|
||||
|
||||
|
||||
class HttpError(LanceDBClientError):
|
||||
"""An error that occurred during an HTTP request.
|
||||
@@ -101,3 +104,19 @@ class RetryError(LanceDBClientError):
|
||||
self.max_request_failures = max_request_failures
|
||||
self.max_connect_failures = max_connect_failures
|
||||
self.max_read_failures = max_read_failures
|
||||
|
||||
def __reduce__(self) -> tuple[type, tuple]:
|
||||
return (
|
||||
self.__class__,
|
||||
(
|
||||
str(self),
|
||||
self.request_id,
|
||||
self.request_failures,
|
||||
self.connect_failures,
|
||||
self.read_failures,
|
||||
self.max_request_failures,
|
||||
self.max_connect_failures,
|
||||
self.max_read_failures,
|
||||
self.status_code,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -125,6 +125,9 @@ class MRRReranker(Reranker):
|
||||
This cannot reuse rerank_hybrid because MRR semantics require treating
|
||||
each vector result as a separate ranking system.
|
||||
"""
|
||||
if not vector_results:
|
||||
raise ValueError("vector_results must not be empty")
|
||||
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
"All elements in vector_results should be of the same type"
|
||||
|
||||
@@ -82,6 +82,9 @@ class RRFReranker(Reranker):
|
||||
results from multiple vector searches as it doesn't support reranking
|
||||
vector results individually.
|
||||
"""
|
||||
if not vector_results:
|
||||
raise ValueError("vector_results must not be empty")
|
||||
|
||||
# Make sure all elements are of the same type
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
|
||||
@@ -30,7 +30,7 @@ from lancedb.scannable import _register_optional_converters, to_scannable
|
||||
|
||||
from . import __version__
|
||||
from lancedb.arrow import peek_reader
|
||||
from lancedb.background_loop import LOOP
|
||||
from lancedb.background_loop import LOOP, embedding_executor
|
||||
from .dependencies import (
|
||||
_check_for_hugging_face,
|
||||
_check_for_lance,
|
||||
@@ -89,6 +89,26 @@ from .index import lang_mapping
|
||||
|
||||
BlobMode = Literal["lazy", "bytes", "descriptions"]
|
||||
|
||||
_VALID_BLOB_MODES = ("lazy", "bytes", "descriptions")
|
||||
|
||||
|
||||
def _validate_blob_mode(blob_mode: BlobMode) -> None:
|
||||
if blob_mode not in _VALID_BLOB_MODES:
|
||||
modes = ", ".join(repr(mode) for mode in _VALID_BLOB_MODES)
|
||||
raise ValueError(f"blob_mode must be one of {modes}, got {blob_mode!r}")
|
||||
|
||||
|
||||
def _field_is_blob(field: pa.Field) -> bool:
|
||||
metadata = field.metadata or {}
|
||||
return metadata.get(b"lance-encoding:blob") == b"true" or (
|
||||
metadata.get("lance-encoding:blob") == "true"
|
||||
)
|
||||
|
||||
|
||||
def _schema_has_blob_field(schema: pa.Schema) -> bool:
|
||||
return any(_field_is_blob(field) for field in schema)
|
||||
|
||||
|
||||
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
|
||||
_MODEL_BACKED_TOKENIZER_ERRORS = (
|
||||
"unknown base tokenizer",
|
||||
@@ -2294,9 +2314,14 @@ class LanceTable(Table):
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy" and (
|
||||
self._namespace_client is not None
|
||||
or get_uri_scheme(self._dataset_path) == "memory"
|
||||
_validate_blob_mode(blob_mode)
|
||||
if blob_mode == "descriptions" or not _schema_has_blob_field(self.schema):
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
|
||||
if (
|
||||
blob_mode == "lazy"
|
||||
and self._namespace_client is None
|
||||
and get_uri_scheme(self._dataset_path) == "memory"
|
||||
):
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
|
||||
@@ -4317,7 +4342,7 @@ class AsyncTable:
|
||||
can be executed with methods like [to_arrow][lancedb.query.AsyncQuery.to_arrow],
|
||||
[to_pandas][lancedb.query.AsyncQuery.to_pandas] and more.
|
||||
"""
|
||||
return AsyncQuery(self._inner.query())
|
||||
return AsyncQuery(self._inner.query(), self)
|
||||
|
||||
async def _to_lance(self, **kwargs) -> lance.LanceDataset:
|
||||
try:
|
||||
@@ -4349,7 +4374,13 @@ class AsyncTable:
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy":
|
||||
_validate_blob_mode(blob_mode)
|
||||
if blob_mode == "descriptions" or not _schema_has_blob_field(
|
||||
await self.schema()
|
||||
):
|
||||
return (await self.to_arrow()).to_pandas(**kwargs)
|
||||
|
||||
if blob_mode == "lazy" and get_uri_scheme(await self.uri()) == "memory":
|
||||
return (await self.to_arrow()).to_pandas(**kwargs)
|
||||
return (await self._to_lance()).to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
|
||||
@@ -4877,10 +4908,13 @@ class AsyncTable:
|
||||
if embedding is not None:
|
||||
loop = asyncio.get_running_loop()
|
||||
# This function is likely to block, since it either calls an expensive
|
||||
# function or makes an HTTP request to an embeddings REST API.
|
||||
# function or makes an HTTP request to an embeddings REST API. Run it
|
||||
# on a dedicated executor so it can't starve the default executor that
|
||||
# other blocking I/O shares. See
|
||||
# https://github.com/lancedb/lancedb/issues/3310.
|
||||
return (
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
embedding_executor(),
|
||||
embedding.function.compute_query_embeddings_with_retry,
|
||||
query,
|
||||
)
|
||||
@@ -5393,7 +5427,7 @@ class AsyncTable:
|
||||
pa.RecordBatch
|
||||
A record batch containing the rows at the given offsets.
|
||||
"""
|
||||
return AsyncTakeQuery(self._inner.take_offsets(offsets))
|
||||
return AsyncTakeQuery(self._inner.take_offsets(offsets), self)
|
||||
|
||||
def take_row_ids(self, row_ids: list[int]) -> AsyncTakeQuery:
|
||||
"""
|
||||
@@ -5422,7 +5456,7 @@ class AsyncTable:
|
||||
AsyncTakeQuery
|
||||
A query object that can be executed to get the rows.
|
||||
"""
|
||||
return AsyncTakeQuery(self._inner.take_row_ids(row_ids))
|
||||
return AsyncTakeQuery(self._inner.take_row_ids(row_ids), self)
|
||||
|
||||
@property
|
||||
def tags(self) -> AsyncTags:
|
||||
@@ -5603,8 +5637,6 @@ class IndexStatistics:
|
||||
The distance type used by the index.
|
||||
num_indices: Optional[int]
|
||||
The number of parts the index is split into.
|
||||
loss: Optional[float]
|
||||
The KMeans loss for the index, for only vector indices.
|
||||
"""
|
||||
|
||||
num_indexed_rows: int
|
||||
@@ -5624,7 +5656,6 @@ class IndexStatistics:
|
||||
]
|
||||
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
||||
num_indices: Optional[int] = None
|
||||
loss: Optional[float] = None
|
||||
|
||||
# This exists for backwards compatibility with an older API, which returned
|
||||
# a dictionary instead of a class.
|
||||
|
||||
56
python/python/tests/test_errors.py
Normal file
56
python/python/tests/test_errors.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pickle
|
||||
|
||||
from lancedb.remote.errors import HttpError, LanceDBClientError, RetryError
|
||||
|
||||
|
||||
def test_pickle_lancedb_client_error():
|
||||
err = LanceDBClientError("something went wrong", "req-123", 400)
|
||||
restored = pickle.loads(pickle.dumps(err))
|
||||
assert str(restored) == "something went wrong"
|
||||
assert restored.request_id == "req-123"
|
||||
assert restored.status_code == 400
|
||||
|
||||
|
||||
def test_pickle_lancedb_client_error_no_status_code():
|
||||
err = LanceDBClientError("fail", "req-456")
|
||||
restored = pickle.loads(pickle.dumps(err))
|
||||
assert str(restored) == "fail"
|
||||
assert restored.request_id == "req-456"
|
||||
assert restored.status_code is None
|
||||
|
||||
|
||||
def test_pickle_http_error():
|
||||
err = HttpError("not found", "req-789", 404)
|
||||
restored = pickle.loads(pickle.dumps(err))
|
||||
assert isinstance(restored, HttpError)
|
||||
assert str(restored) == "not found"
|
||||
assert restored.request_id == "req-789"
|
||||
assert restored.status_code == 404
|
||||
|
||||
|
||||
def test_pickle_retry_error():
|
||||
err = RetryError(
|
||||
"max retries exceeded",
|
||||
"req-abc",
|
||||
request_failures=3,
|
||||
connect_failures=1,
|
||||
read_failures=2,
|
||||
max_request_failures=5,
|
||||
max_connect_failures=3,
|
||||
max_read_failures=3,
|
||||
status_code=503,
|
||||
)
|
||||
restored = pickle.loads(pickle.dumps(err))
|
||||
assert isinstance(restored, RetryError)
|
||||
assert str(restored) == "max retries exceeded"
|
||||
assert restored.request_id == "req-abc"
|
||||
assert restored.request_failures == 3
|
||||
assert restored.connect_failures == 1
|
||||
assert restored.read_failures == 2
|
||||
assert restored.max_request_failures == 5
|
||||
assert restored.max_connect_failures == 3
|
||||
assert restored.max_read_failures == 3
|
||||
assert restored.status_code == 503
|
||||
@@ -226,7 +226,6 @@ async def test_create_vector_index(some_table: AsyncTable):
|
||||
assert stats.num_indexed_rows == await some_table.count_rows()
|
||||
assert stats.num_unindexed_rows == 0
|
||||
assert stats.num_indices == 1
|
||||
assert stats.loss >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -250,7 +249,6 @@ async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
|
||||
assert stats.num_indexed_rows == await some_table.count_rows()
|
||||
assert stats.num_unindexed_rows == 0
|
||||
assert stats.num_indices == 1
|
||||
assert stats.loss >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -76,6 +76,35 @@ class TestNamespaceConnection:
|
||||
assert len(result) == 0
|
||||
assert list(result.columns) == ["id", "vector", "text"]
|
||||
|
||||
def test_table_to_pandas_blob_lazy_through_namespace(self):
|
||||
"""Namespace-backed tables should use Lance blob-aware pandas conversion."""
|
||||
pytest.importorskip("lance")
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
db.create_namespace(["test_ns"])
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob",
|
||||
pa.large_binary(),
|
||||
metadata={"lance-encoding:blob": "true"},
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
table = db.create_table("blob_table", data, namespace_path=["test_ns"])
|
||||
df = table.to_pandas(blob_mode="lazy").sort_values("id")
|
||||
|
||||
blob = df["blob"].iloc[0]
|
||||
assert hasattr(blob, "readall")
|
||||
assert blob.readall() == b"hello"
|
||||
|
||||
def test_open_table_through_namespace(self):
|
||||
"""Test opening an existing table through namespace."""
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
|
||||
@@ -39,6 +39,35 @@ from utils import exception_output
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
def _blob_query_data():
|
||||
return pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2, 3, 4], pa.int64()),
|
||||
"tag": pa.array(["drop", "keep", "keep", "keep"], pa.utf8()),
|
||||
"vector": pa.array(
|
||||
[[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]],
|
||||
type=pa.list_(pa.float32(), list_size=2),
|
||||
),
|
||||
"blob": pa.array([b"one", b"two", b"three", b"four"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field("tag", pa.utf8()),
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_lazy_blob(value, expected: bytes):
|
||||
assert hasattr(value, "readall")
|
||||
assert value.readall() == expected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def table(tmpdir_factory) -> lancedb.table.Table:
|
||||
tmp_path = str(tmpdir_factory.mktemp("data"))
|
||||
@@ -181,6 +210,216 @@ async def test_query_to_pandas_kwargs(table, table_async):
|
||||
assert async_df["id"].tolist() == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
|
||||
def test_plain_scan_query_to_pandas_blob_modes(tmp_db, blob_mode):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
f"test_query_to_pandas_blob_{blob_mode}", _blob_query_data()
|
||||
)
|
||||
|
||||
df = (
|
||||
table.search()
|
||||
.select(["id", "blob"])
|
||||
.where("id = 1")
|
||||
.to_pandas(blob_mode=blob_mode)
|
||||
)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
if blob_mode == "lazy":
|
||||
_assert_lazy_blob(df["blob"].iloc[0], b"one")
|
||||
elif blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_blob_projection(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"test_query_to_pandas_blob_projection", _blob_query_data()
|
||||
)
|
||||
|
||||
df = (
|
||||
table.search()
|
||||
.where("id >= 2")
|
||||
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
|
||||
.limit(2)
|
||||
.offset(1)
|
||||
.to_pandas(blob_mode="bytes")
|
||||
)
|
||||
|
||||
assert df["id_alias"].tolist() == [3, 4]
|
||||
assert df["payload"].tolist() == [b"three", b"four"]
|
||||
assert df["double_id"].tolist() == [6, 8]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
|
||||
def test_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
|
||||
tmp_db, monkeypatch, blob_mode
|
||||
):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"test_query_to_pandas_blob_no_arrow_collect", _blob_query_data()
|
||||
)
|
||||
query = table.search().where("id = 1").select(["id", "blob"])
|
||||
|
||||
def fail_to_arrow(*args, **kwargs):
|
||||
raise AssertionError("to_arrow should not be called before native pandas")
|
||||
|
||||
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
|
||||
|
||||
df = query.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
if blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_blob_descriptions_flatten_uses_scanner(
|
||||
tmp_db, monkeypatch
|
||||
):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"test_query_to_pandas_blob_desc_flatten", _blob_query_data()
|
||||
)
|
||||
query = table.search().where("id = 1").select(["id", "blob"])
|
||||
|
||||
def fail_to_arrow(*args, **kwargs):
|
||||
raise AssertionError("to_arrow should not be called before scanner pandas")
|
||||
|
||||
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
|
||||
|
||||
df = query.to_pandas(blob_mode="descriptions", flatten=True)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
assert any(column == "blob" or column.startswith("blob.") for column in df.columns)
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_scanner_state(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
data = _blob_query_data()
|
||||
table = tmp_db.create_table("test_query_to_pandas_scanner_state", data.slice(0, 2))
|
||||
table.add(data.slice(2, 2))
|
||||
|
||||
fragments = table.to_lance().get_fragments()
|
||||
assert len(fragments) == 2
|
||||
|
||||
query = (
|
||||
table.search()
|
||||
.select(["id", "blob"])
|
||||
.with_row_address()
|
||||
.fragment_ids([fragments[1].fragment_id])
|
||||
)
|
||||
query_obj = query.to_query_object()
|
||||
assert query_obj.with_row_address is True
|
||||
assert query_obj.fragment_ids == [fragments[1].fragment_id]
|
||||
|
||||
df = query.to_pandas(blob_mode="descriptions")
|
||||
|
||||
assert df["id"].tolist() == [3, 4]
|
||||
assert "_rowaddr" in df.columns
|
||||
assert {rowaddr >> 32 for rowaddr in df["_rowaddr"]} == {fragments[1].fragment_id}
|
||||
|
||||
df_by_fragment = (
|
||||
table.search()
|
||||
.select(["id", "blob"])
|
||||
.with_fragments([fragments[0]])
|
||||
.to_pandas(blob_mode="descriptions")
|
||||
)
|
||||
assert df_by_fragment["id"].tolist() == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_plain_scan_query_to_pandas_blob_projection(tmp_db_async):
|
||||
pytest.importorskip("lance")
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_query_to_pandas_blob_projection", _blob_query_data()
|
||||
)
|
||||
|
||||
lazy_df = await (
|
||||
table.query().where("id = 1").select(["id", "blob"]).to_pandas(blob_mode="lazy")
|
||||
)
|
||||
assert lazy_df["id"].tolist() == [1]
|
||||
_assert_lazy_blob(lazy_df["blob"].iloc[0], b"one")
|
||||
|
||||
bytes_df = await (
|
||||
table.query()
|
||||
.where("id >= 2")
|
||||
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
|
||||
.limit(2)
|
||||
.offset(1)
|
||||
.to_pandas(blob_mode="bytes")
|
||||
)
|
||||
assert bytes_df["id_alias"].tolist() == [3, 4]
|
||||
assert bytes_df["payload"].tolist() == [b"three", b"four"]
|
||||
assert bytes_df["double_id"].tolist() == [6, 8]
|
||||
|
||||
desc_df = await (
|
||||
table.query()
|
||||
.where("id = 1")
|
||||
.select(["blob"])
|
||||
.to_pandas(blob_mode="descriptions")
|
||||
)
|
||||
first = desc_df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("blob_mode", ["bytes", "descriptions"])
|
||||
async def test_async_plain_scan_query_to_pandas_blob_mode_does_not_collect_arrow(
|
||||
tmp_db_async, monkeypatch, blob_mode
|
||||
):
|
||||
pytest.importorskip("lance")
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_query_to_pandas_blob_no_arrow_collect", _blob_query_data()
|
||||
)
|
||||
query = table.query().where("id = 1").select(["id", "blob"])
|
||||
|
||||
async def fail_to_arrow(*args, **kwargs):
|
||||
raise AssertionError("to_arrow should not be called before native pandas")
|
||||
|
||||
monkeypatch.setattr(query, "to_arrow", fail_to_arrow)
|
||||
|
||||
df = await query.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
if blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table("test_vector_query_blob_mode", _blob_query_data())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Lance native pandas conversion"):
|
||||
table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
|
||||
blob_mode="lazy"
|
||||
)
|
||||
|
||||
|
||||
def test_vector_query_to_pandas_blob_descriptions_requires_plain_scan(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"test_vector_query_blob_descriptions", _blob_query_data()
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="plain scan query"):
|
||||
table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
|
||||
blob_mode="descriptions"
|
||||
)
|
||||
|
||||
|
||||
def test_order_by_plain_query(mem_db):
|
||||
table = mem_db.create_table(
|
||||
"test_order_by",
|
||||
|
||||
@@ -344,6 +344,12 @@ def test_mrr_reranker(tmp_path):
|
||||
assert len(result_deduped) == len(result)
|
||||
|
||||
|
||||
def test_mrr_reranker_empty_input():
|
||||
reranker = MRRReranker()
|
||||
with pytest.raises(ValueError, match="must not be empty"):
|
||||
reranker.rerank_multivector([])
|
||||
|
||||
|
||||
def test_rrf_reranker_distance():
|
||||
data = pa.table(
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from datetime import date, datetime, timedelta
|
||||
from time import sleep
|
||||
@@ -26,6 +27,28 @@ from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _blob_test_data():
|
||||
return pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_lazy_blob(value, expected: bytes):
|
||||
assert hasattr(value, "readall")
|
||||
assert value.readall() == expected
|
||||
|
||||
|
||||
def test_basic(mem_db: DBConnection):
|
||||
data = [
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
@@ -57,27 +80,30 @@ def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
|
||||
pd.testing.assert_frame_equal(table.to_pandas(), expected)
|
||||
|
||||
|
||||
def test_table_to_pandas_blob_bytes(tmp_db: DBConnection):
|
||||
def test_table_to_pandas_invalid_blob_mode_non_blob_table(tmp_db: DBConnection):
|
||||
data = pa.table({"id": [1, 2], "text": ["one", "two"]})
|
||||
table = tmp_db.create_table("test_to_pandas_invalid_blob_mode", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="blob_mode must be one of"):
|
||||
table.to_pandas(blob_mode="invalid")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
|
||||
def test_table_to_pandas_blob_modes(tmp_db: DBConnection, blob_mode):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = tmp_db.create_table("test_to_pandas_blob_bytes", data=data)
|
||||
table = tmp_db.create_table(f"test_to_pandas_blob_{blob_mode}", _blob_test_data())
|
||||
|
||||
df = table.to_pandas(blob_mode="bytes")
|
||||
df = table.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
if blob_mode == "lazy":
|
||||
_assert_lazy_blob(df["blob"].iloc[0], b"hello")
|
||||
_assert_lazy_blob(df["blob"].iloc[1], b"world")
|
||||
elif blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"hello"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_table_to_pandas_kwargs(tmp_db: DBConnection):
|
||||
@@ -93,22 +119,8 @@ def test_table_to_pandas_kwargs(tmp_db: DBConnection):
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_to_pandas_blob_bytes", data=data
|
||||
"test_async_to_pandas_blob_bytes", data=_blob_test_data()
|
||||
)
|
||||
|
||||
df = await table.to_pandas(blob_mode="bytes")
|
||||
@@ -116,6 +128,19 @@ async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_invalid_blob_mode_non_blob_table(
|
||||
tmp_db_async: AsyncConnection,
|
||||
):
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_to_pandas_invalid_blob_mode",
|
||||
data=pa.table({"id": [1, 2], "text": ["one", "two"]}),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="blob_mode must be one of"):
|
||||
await table.to_pandas(blob_mode="invalid")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_kwargs(tmp_db_async: AsyncConnection):
|
||||
pd = pytest.importorskip("pandas")
|
||||
@@ -1264,6 +1289,45 @@ def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
|
||||
|
||||
|
||||
def test_add_nullable_struct_with_none(mem_db: DBConnection):
|
||||
"""Regression test for issue #2654: a nullable struct column whose
|
||||
first batch contains only None values must not crash in
|
||||
_align_field_types with AttributeError: 'pyarrow.lib.DataType'
|
||||
object has no attribute 'fields'.
|
||||
|
||||
PyArrow infers an all-None struct column as `null` (not `struct`),
|
||||
so the type-alignment path needs to handle the case where the
|
||||
source field type is null and use the target type directly.
|
||||
"""
|
||||
# Use the v2.1 file format so that nullable structs are supported.
|
||||
table = mem_db.create_table(
|
||||
"test_nullable_struct",
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.string()),
|
||||
pa.field(
|
||||
"data",
|
||||
pa.struct([pa.field("x", pa.float32())]),
|
||||
nullable=True,
|
||||
),
|
||||
]
|
||||
),
|
||||
storage_options=dict(new_table_data_storage_version="2.1"),
|
||||
)
|
||||
|
||||
# Adding a row with a non-null struct should work.
|
||||
table.add([{"id": "1", "data": {"x": 1.0}}])
|
||||
|
||||
# Adding a row with None for the nullable struct field should also
|
||||
# work — this is what used to crash.
|
||||
table.add([{"id": "2", "data": None}])
|
||||
|
||||
result = table.to_arrow()
|
||||
assert result.num_rows == 2
|
||||
assert result.column("id").to_pylist() == ["1", "2"]
|
||||
assert result.column("data").to_pylist() == [{"x": 1.0}, None]
|
||||
|
||||
|
||||
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
@@ -2774,3 +2838,38 @@ def test_sanitize_data_metadata_not_stripped():
|
||||
assert result_schema.metadata is not None
|
||||
assert result_schema.metadata[b"existing_key"] == b"existing_value"
|
||||
assert result_schema.metadata[b"new_key"] == b"new_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_search_runs_embedding_on_dedicated_executor(
|
||||
mem_db_async: AsyncConnection,
|
||||
):
|
||||
# Regression test for #3310: AsyncTable.search() must run the (potentially
|
||||
# blocking) query-embedding call on the dedicated embedding executor, not
|
||||
# asyncio's default executor -- which is shared with other blocking I/O and
|
||||
# can be starved by a slow embedding call under concurrent load.
|
||||
func = MockTextEmbeddingFunction.create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = await mem_db_async.create_table("embed_executor", schema=Schema)
|
||||
await table.add([{"text": "hello world"}])
|
||||
|
||||
captured_threads: List[str] = []
|
||||
original = MockTextEmbeddingFunction.generate_embeddings
|
||||
|
||||
def record_thread(self, texts):
|
||||
captured_threads.append(threading.current_thread().name)
|
||||
return original(self, texts)
|
||||
|
||||
# Patch only around the search so we capture the query-embedding call, not
|
||||
# the add-time source-embedding call.
|
||||
with patch.object(MockTextEmbeddingFunction, "generate_embeddings", record_thread):
|
||||
await (await table.search("a query string")).limit(1).to_list()
|
||||
|
||||
assert captured_threads, "search did not invoke the embedding function"
|
||||
assert all(name.startswith("lancedb-embedding") for name in captured_threads), (
|
||||
f"embedding ran off the dedicated executor: {captured_threads}"
|
||||
)
|
||||
|
||||
@@ -711,10 +711,6 @@ impl Table {
|
||||
dict.set_item("num_indices", num_indices)?;
|
||||
}
|
||||
|
||||
if let Some(loss) = stats.loss {
|
||||
dict.set_item("loss", loss)?;
|
||||
}
|
||||
|
||||
Ok(Some(dict.unbind()))
|
||||
})
|
||||
} else {
|
||||
|
||||
4226
python/uv.lock
generated
4226
python/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.30.1-beta.2"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -203,11 +203,11 @@ impl Shuffler {
|
||||
|
||||
// Finish writing files
|
||||
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
|
||||
let num_written = writer.finish().await?;
|
||||
let write_summary = writer.finish().await?;
|
||||
log::debug!(
|
||||
"Shuffle job {}: wrote {} rows to file {}",
|
||||
self.id,
|
||||
num_written,
|
||||
write_summary.num_rows,
|
||||
file_idx
|
||||
);
|
||||
}
|
||||
|
||||
@@ -372,7 +372,6 @@ pub(crate) struct IndexMetadata {
|
||||
pub metric_type: Option<DistanceType>,
|
||||
// Sometimes the index type is provided at this level.
|
||||
pub index_type: Option<IndexType>,
|
||||
pub loss: Option<f64>,
|
||||
}
|
||||
|
||||
// This struct is used to deserialize the JSON data returned from the Lance API
|
||||
@@ -404,6 +403,4 @@ pub struct IndexStatistics {
|
||||
pub distance_type: Option<DistanceType>,
|
||||
/// The number of parts this index is split into.
|
||||
pub num_indices: Option<u32>,
|
||||
/// The loss value used by the index.
|
||||
pub loss: Option<f64>,
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use crate::table::DropColumnsResult;
|
||||
use crate::table::MergeResult;
|
||||
use crate::table::Tags;
|
||||
use crate::table::UpdateResult;
|
||||
use crate::table::merge::MergeFilter;
|
||||
use crate::table::query::create_multi_vector_plan;
|
||||
use crate::table::{AlterColumnsResult, FieldMetadataUpdate, UpdateFieldMetadataResult};
|
||||
use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics};
|
||||
@@ -1826,16 +1827,57 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
})
|
||||
}
|
||||
|
||||
async fn set_lsm_write_spec(&self, _spec: crate::table::LsmWriteSpec) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "set_lsm_write_spec is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
async fn set_lsm_write_spec(&self, spec: crate::table::LsmWriteSpec) -> Result<()> {
|
||||
use crate::table::LsmWriteSpec;
|
||||
self.check_mutable().await?;
|
||||
|
||||
// Map the spec onto the server's request DTO. `sharding` is internally
|
||||
// tagged on `mode` to mirror sophon's `Sharding` enum; `maintained_indexes`
|
||||
// and `writer_config_defaults` are sent verbatim (an empty list means "no
|
||||
// maintained indexes", not "default to all").
|
||||
let sharding = match &spec {
|
||||
LsmWriteSpec::Bucket {
|
||||
column,
|
||||
num_buckets,
|
||||
..
|
||||
} => serde_json::json!({
|
||||
"mode": "bucket",
|
||||
"column": column,
|
||||
"num_buckets": num_buckets,
|
||||
}),
|
||||
LsmWriteSpec::Identity { column, .. } => serde_json::json!({
|
||||
"mode": "identity",
|
||||
"column": column,
|
||||
}),
|
||||
LsmWriteSpec::Unsharded { .. } => serde_json::json!({ "mode": "unsharded" }),
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"sharding": sharding,
|
||||
"maintained_indexes": spec.maintained_indexes(),
|
||||
"writer_config_defaults": spec.writer_config_defaults(),
|
||||
});
|
||||
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!(
|
||||
"/v1/table/{}/set_lsm_write_spec/",
|
||||
self.identifier
|
||||
))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unset_lsm_write_spec(&self) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "unset_lsm_write_spec is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
self.check_mutable().await?;
|
||||
let request = self.client.post(&format!(
|
||||
"/v1/table/{}/unset_lsm_write_spec/",
|
||||
self.identifier
|
||||
));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
|
||||
@@ -2266,13 +2308,34 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
|
||||
}
|
||||
let on = value.on[0].clone();
|
||||
|
||||
let when_matched_update_all_filt = match value.when_matched_update_all_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let when_not_matched_by_source_delete_filt =
|
||||
match value.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(sql)) => Some(sql),
|
||||
Some(MergeFilter::Expr(_)) => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "DataFusion expressions are not supported on remote tables".into(),
|
||||
});
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
on,
|
||||
when_matched_update_all: value.when_matched_update_all,
|
||||
when_matched_update_all_filt: value.when_matched_update_all_filt,
|
||||
when_matched_update_all_filt,
|
||||
when_not_matched_insert_all: value.when_not_matched_insert_all,
|
||||
when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
|
||||
when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
|
||||
when_not_matched_by_source_delete_filt,
|
||||
// Only serialize use_index when it's false for backwards compatibility
|
||||
use_index: value.use_index,
|
||||
})
|
||||
@@ -4058,7 +4121,6 @@ mod tests {
|
||||
index_type: IndexType::IvfPq,
|
||||
distance_type: Some(DistanceType::L2),
|
||||
num_indices: None,
|
||||
loss: None,
|
||||
};
|
||||
assert_eq!(indices, expected);
|
||||
|
||||
@@ -4406,6 +4468,91 @@ mod tests {
|
||||
assert!(matches!(e, Error::IndexNotFound { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_unsharded() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(body["sharding"], serde_json::json!({ "mode": "unsharded" }));
|
||||
assert_eq!(body["maintained_indexes"], serde_json::json!(["id_idx"]));
|
||||
assert_eq!(
|
||||
body["writer_config_defaults"],
|
||||
serde_json::json!({ "max_memtable_rows": "1000" })
|
||||
);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"maintained_indexes":["id_idx"]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let spec = crate::table::LsmWriteSpec::unsharded()
|
||||
.with_maintained_indexes(["id_idx"])
|
||||
.with_writer_config_defaults([("max_memtable_rows", "1000")]);
|
||||
table.set_lsm_write_spec(spec).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_bucket() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(
|
||||
body["sharding"],
|
||||
serde_json::json!({ "mode": "bucket", "column": "id", "num_buckets": 16 })
|
||||
);
|
||||
assert_eq!(body["maintained_indexes"], serde_json::json!([]));
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table
|
||||
.set_lsm_write_spec(crate::table::LsmWriteSpec::bucket("id", 16))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_lsm_write_spec_identity() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/set_lsm_write_spec/"
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(
|
||||
body["sharding"],
|
||||
serde_json::json!({ "mode": "identity", "column": "tenant" })
|
||||
);
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table
|
||||
.set_lsm_write_spec(crate::table::LsmWriteSpec::identity("tenant"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unset_lsm_write_spec() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/unset_lsm_write_spec/"
|
||||
);
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
table.unset_lsm_write_spec().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_index() {
|
||||
let table = _make_table_with_indices(0);
|
||||
|
||||
@@ -3019,20 +3019,12 @@ impl BaseTable for NativeTable {
|
||||
.ok_or_else(|| Error::InvalidInput {
|
||||
message: "index statistics was missing index type".to_string(),
|
||||
})?;
|
||||
let loss = stats
|
||||
.indices
|
||||
.iter()
|
||||
.map(|index| index.loss.unwrap_or_default())
|
||||
.sum::<f64>();
|
||||
|
||||
let loss = first_index.loss.map(|first_loss| first_loss + loss);
|
||||
Ok(Some(IndexStatistics {
|
||||
num_indexed_rows: stats.num_indexed_rows,
|
||||
num_unindexed_rows: stats.num_unindexed_rows,
|
||||
index_type,
|
||||
distance_type: first_index.metric_type,
|
||||
num_indices: stats.num_indices,
|
||||
loss,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -3435,7 +3427,6 @@ mod tests {
|
||||
assert_eq!(stats.num_unindexed_rows, 0);
|
||||
assert_eq!(stats.index_type, crate::index::IndexType::IvfPq);
|
||||
assert_eq!(stats.distance_type, Some(crate::DistanceType::L2));
|
||||
assert!(stats.loss.is_some());
|
||||
|
||||
table.drop_index(index_name).await.unwrap();
|
||||
assert_eq!(table.list_indices().await.unwrap().len(), 0);
|
||||
|
||||
@@ -53,6 +53,12 @@ pub struct MergeResult {
|
||||
pub num_rows: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MergeFilter {
|
||||
Sql(String),
|
||||
Expr(datafusion_expr::Expr),
|
||||
}
|
||||
|
||||
/// A builder used to create and run a merge insert operation
|
||||
///
|
||||
/// See [`super::Table::merge_insert`] for more context
|
||||
@@ -61,10 +67,10 @@ pub struct MergeInsertBuilder {
|
||||
table: Arc<dyn BaseTable>,
|
||||
pub(crate) on: Vec<String>,
|
||||
pub(crate) when_matched_update_all: bool,
|
||||
pub(crate) when_matched_update_all_filt: Option<String>,
|
||||
pub(crate) when_matched_update_all_filt: Option<MergeFilter>,
|
||||
pub(crate) when_not_matched_insert_all: bool,
|
||||
pub(crate) when_not_matched_by_source_delete: bool,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<MergeFilter>,
|
||||
pub(crate) timeout: Option<Duration>,
|
||||
pub(crate) use_index: bool,
|
||||
pub(crate) use_lsm_write: Option<bool>,
|
||||
@@ -110,7 +116,14 @@ impl MergeInsertBuilder {
|
||||
/// For example, "target.last_update < source.last_update"
|
||||
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = condition;
|
||||
self.when_matched_update_all_filt = condition.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_matched_update_all`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_matched_update_all_expr(&mut self, condition: datafusion_expr::Expr) -> &mut Self {
|
||||
self.when_matched_update_all = true;
|
||||
self.when_matched_update_all_filt = Some(MergeFilter::Expr(condition));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -132,7 +145,17 @@ impl MergeInsertBuilder {
|
||||
/// limit what rows are deleted.
|
||||
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = filter;
|
||||
self.when_not_matched_by_source_delete_filt = filter.map(MergeFilter::Sql);
|
||||
self
|
||||
}
|
||||
|
||||
/// Similar to [`Self::when_not_matched_by_source_delete`] but accepts a DataFusion logical expression directly.
|
||||
pub fn when_not_matched_by_source_delete_expr(
|
||||
&mut self,
|
||||
filter: datafusion_expr::Expr,
|
||||
) -> &mut Self {
|
||||
self.when_not_matched_by_source_delete = true;
|
||||
self.when_not_matched_by_source_delete_filt = Some(MergeFilter::Expr(filter));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -234,7 +257,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
) {
|
||||
(false, _) => builder.when_matched(WhenMatched::DoNothing),
|
||||
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
|
||||
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
|
||||
(true, Some(MergeFilter::Sql(filt))) => {
|
||||
builder.when_matched(WhenMatched::update_if(&dataset, &filt)?)
|
||||
}
|
||||
(true, Some(MergeFilter::Expr(expr))) => {
|
||||
builder.when_matched(WhenMatched::update_if_expr(expr))
|
||||
}
|
||||
};
|
||||
if params.when_not_matched_insert_all {
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
|
||||
@@ -242,10 +270,12 @@ pub(crate) async fn execute_merge_insert(
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
|
||||
}
|
||||
if params.when_not_matched_by_source_delete {
|
||||
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
} else {
|
||||
WhenNotMatchedBySource::Delete
|
||||
let behavior = match params.when_not_matched_by_source_delete_filt {
|
||||
Some(MergeFilter::Sql(filter)) => {
|
||||
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
|
||||
}
|
||||
Some(MergeFilter::Expr(expr)) => WhenNotMatchedBySource::DeleteIf(expr),
|
||||
None => WhenNotMatchedBySource::Delete,
|
||||
};
|
||||
builder.when_not_matched_by_source(behavior);
|
||||
} else {
|
||||
@@ -386,6 +416,45 @@ mod tests {
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 25);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_insert_expr() {
|
||||
use datafusion_expr::{col, lit};
|
||||
|
||||
let conn = connect("memory://").execute().await.unwrap();
|
||||
|
||||
// Create a dataset with i=0..10
|
||||
let batches = merge_insert_test_batches(0, 0);
|
||||
let table = conn
|
||||
.create_table("my_table_expr", batches)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
// Conditional update that only replaces the age=0 data
|
||||
let new_batches = merge_insert_test_batches(5, 3);
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// use expression: target.age = 0
|
||||
let expr = col("target.age").eq(lit(0));
|
||||
merge_insert_builder.when_matched_update_all_expr(expr);
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(
|
||||
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
|
||||
// Delete with expression
|
||||
// Create new batches with i=10..20 (so target rows i=0..9 are not matched by source)
|
||||
let new_batches = merge_insert_test_batches(10, 0); // won't insert or update since we don't enable matched/unmatched actions
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
// delete if target.age = 3
|
||||
let delete_expr = col("target.age").eq(lit(3));
|
||||
merge_insert_builder.when_not_matched_by_source_delete_expr(delete_expr);
|
||||
let result = merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(result.num_deleted_rows, 5);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user