Compare commits

..

1 Commits

Author SHA1 Message Date
Lance Release
f26d688641 Bump version: 0.27.0-beta.6 → 0.27.0 2026-03-16 22:46:53 +00:00
41 changed files with 1476 additions and 2919 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.27.2-beta.1"
current_version = "0.27.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -207,14 +207,14 @@ jobs:
- name: Downgrade dependencies
# These packages have newer requirements for MSRV
run: |
cargo update -p aws-sdk-bedrockruntime --precise 1.77.0
cargo update -p aws-sdk-dynamodb --precise 1.68.0
cargo update -p aws-config --precise 1.6.0
cargo update -p aws-sdk-kms --precise 1.63.0
cargo update -p aws-sdk-s3 --precise 1.79.0
cargo update -p aws-sdk-sso --precise 1.62.0
cargo update -p aws-sdk-ssooidc --precise 1.63.0
cargo update -p aws-sdk-sts --precise 1.63.0
cargo update -p aws-sdk-bedrockruntime --precise 1.64.0
cargo update -p aws-sdk-dynamodb --precise 1.55.0
cargo update -p aws-config --precise 1.5.10
cargo update -p aws-sdk-kms --precise 1.51.0
cargo update -p aws-sdk-s3 --precise 1.65.0
cargo update -p aws-sdk-sso --precise 1.50.0
cargo update -p aws-sdk-ssooidc --precise 1.51.0
cargo update -p aws-sdk-sts --precise 1.51.0
cargo update -p home --precise 0.5.9
- name: cargo +${{ matrix.msrv }} check
env:

2124
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,20 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance = { version = "=3.0.0", default-features = false }
lance-core = { version = "=3.0.0" }
lance-datagen = { version = "=3.0.0" }
lance-file = { version = "=3.0.0" }
lance-io = { version = "=3.0.0", default-features = false }
lance-index = { version = "=3.0.0" }
lance-linalg = { version = "=3.0.0" }
lance-namespace = { version = "=3.0.0" }
lance-namespace-impls = { version = "=3.0.0", default-features = false }
lance-table = { version = "=3.0.0" }
lance-testing = { version = "=3.0.0" }
lance-datafusion = { version = "=3.0.0" }
lance-encoding = { version = "=3.0.0" }
lance-arrow = { version = "=3.0.0" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import argparse
import functools
import json
import os
import re
@@ -27,7 +26,6 @@ SEMVER_RE = re.compile(
)
@functools.total_ordering
@dataclass(frozen=True)
class SemVer:
major: int
@@ -158,9 +156,7 @@ def read_current_version(repo_root: Path) -> str:
def determine_latest_tag(tags: Iterable[TagInfo]) -> TagInfo:
# Stable releases (no prerelease) are always preferred over pre-releases.
# Within each group, standard semver ordering applies.
return max(tags, key=lambda tag: (not tag.semver.prerelease, tag.semver))
return max(tags, key=lambda tag: tag.semver)
def write_outputs(args: argparse.Namespace, payload: dict) -> None:

View File

@@ -1,8 +1,8 @@
mkdocs==1.6.1
mkdocs==1.5.3
mkdocs-jupyter==0.24.1
mkdocs-material==9.6.23
mkdocs-material==9.5.3
mkdocs-autorefs>=0.5,<=1.0
mkdocstrings[python]>=0.24,<1.0
mkdocstrings[python]==0.25.2
griffe>=0.40,<1.0
mkdocs-render-swagger-plugin>=0.1.0
pydantic>=2.0,<3.0

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.27.2-beta.1</version>
<version>0.27.0</version>
</dependency>
```

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.27.2-beta.1</version>
<version>0.27.0-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.27.2-beta.1</version>
<version>0.27.0-final.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>3.0.1</lance-core.version>
<lance-core.version>3.1.0-beta.2</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.27.2-beta.1"
version = "0.27.0"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.27.2-beta.1",
"version": "0.27.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.27.2-beta.1",
"version": "0.27.0-beta.5",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.27.2-beta.1",
"version": "0.27.0-beta.5",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.27.2-beta.1",
"version": "0.27.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.30.2-beta.1"
current_version = "0.30.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.30.2-beta.1"
version = "0.30.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -23,7 +23,6 @@ lance-namespace.workspace = true
lance-namespace-impls.workspace = true
lance-io.workspace = true
env_logger.workspace = true
log.workspace = true
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.26", features = [
"attributes",

View File

@@ -135,10 +135,7 @@ class Table:
def close(self) -> None: ...
async def schema(self) -> pa.Schema: ...
async def add(
self,
data: pa.RecordBatchReader,
mode: Literal["append", "overwrite"],
progress: Optional[Any] = None,
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
) -> AddResult: ...
async def update(
self, updates: Dict[str, str], where: Optional[str]

View File

@@ -70,7 +70,7 @@ def ensure_vector_query(
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
if isinstance(val, list):
if len(val) == 0:
raise ValueError("Vector query must be a non-empty list")
return ValueError("Vector query must be a non-empty list")
sample = val[0]
else:
if isinstance(val, float):
@@ -83,7 +83,7 @@ def ensure_vector_query(
return val
if isinstance(sample, list):
if len(sample) == 0:
raise ValueError("Vector query must be a non-empty list")
return ValueError("Vector query must be a non-empty list")
if isinstance(sample[0], float):
# val is list of list of floats
return val
@@ -2205,8 +2205,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, not self._postfilter)
self._fts_query.where(self._where, not self._postfilter)
self._vector_query.where(self._where, self._postfilter)
self._fts_query.where(self._where, self._postfilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)

View File

@@ -4,7 +4,7 @@
from datetime import timedelta
import logging
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
from typing import Dict, Iterable, List, Optional, Union, Literal
import warnings
from lancedb._lancedb import (
@@ -35,7 +35,6 @@ import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.table import _normalize_progress
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
@@ -309,7 +308,6 @@ class RemoteTable(Table):
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table). It has the same API signature as
the OSS version.
@@ -332,29 +330,17 @@ class RemoteTable(Table):
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
Returns
-------
AddResult
An object containing the new version number of the table after adding data.
"""
progress, owns = _normalize_progress(progress)
try:
return LOOP.run(
self._table.add(
data,
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
progress=progress,
)
return LOOP.run(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
finally:
if owns:
progress.close()
)
def search(
self,

View File

@@ -14,7 +14,6 @@ from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
@@ -557,21 +556,6 @@ def _table_uri(base: str, table_name: str) -> str:
return join_uri(base, f"{table_name}.lance")
def _normalize_progress(progress):
"""Normalize a ``progress`` parameter for :meth:`Table.add`.
Returns ``(progress_obj, owns)`` where *owns* is True when we created a
tqdm bar that the caller must close.
"""
if progress is True:
from tqdm.auto import tqdm
return tqdm(unit=" rows"), True
if progress is False or progress is None:
return None, False
return progress, False
class Table(ABC):
"""
A Table is a collection of Records in a LanceDB Database.
@@ -990,7 +974,6 @@ class Table(ABC):
mode: AddMode = "append",
on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table).
@@ -1012,29 +995,6 @@ class Table(ABC):
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
Progress reporting during the add operation. Can be:
- ``True`` to automatically create and display a tqdm progress
bar (requires ``tqdm`` to be installed)::
table.add(data, progress=True)
- A **callable** that receives a dict with keys ``output_rows``,
``output_bytes``, ``total_rows``, ``elapsed_seconds``,
``active_tasks``, ``total_tasks``, and ``done``::
def on_progress(p):
print(f"{p['output_rows']}/{p['total_rows']} rows, "
f"{p['active_tasks']}/{p['total_tasks']} workers")
table.add(data, progress=on_progress)
- A **tqdm-compatible** progress bar whose ``total`` and
``update()`` will be called automatically. The postfix shows
write throughput (MB/s) and active worker count::
with tqdm() as pbar:
table.add(data, progress=pbar)
Returns
-------
@@ -2532,7 +2492,6 @@ class LanceTable(Table):
mode: AddMode = "append",
on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add data to the table.
If vector columns are missing and the table
@@ -2551,29 +2510,17 @@ class LanceTable(Table):
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
Returns
-------
int
The number of vectors in the table.
"""
progress, owns = _normalize_progress(progress)
try:
return LOOP.run(
self._table.add(
data,
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
progress=progress,
)
return LOOP.run(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
finally:
if owns:
progress.close()
)
def merge(
self,
@@ -3822,7 +3769,6 @@ class AsyncTable:
mode: Optional[Literal["append", "overwrite"]] = "append",
on_bad_vectors: Optional[OnBadVectorsType] = None,
fill_value: Optional[float] = None,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table).
@@ -3844,9 +3790,6 @@ class AsyncTable:
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: callable or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
"""
schema = await self.schema()
@@ -3870,9 +3813,8 @@ class AsyncTable:
)
_register_optional_converters()
data = to_scannable(data)
progress, owns = _normalize_progress(progress)
try:
return await self._inner.add(data, mode or "append", progress=progress)
return await self._inner.add(data, mode or "append")
except RuntimeError as e:
if "Cast error" in str(e):
raise ValueError(e)
@@ -3880,9 +3822,6 @@ class AsyncTable:
raise ValueError(e)
else:
raise
finally:
if owns:
progress.close()
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
@@ -4812,16 +4751,7 @@ class IndexStatistics:
num_indexed_rows: int
num_unindexed_rows: int
index_type: Literal[
"IVF_FLAT",
"IVF_SQ",
"IVF_PQ",
"IVF_RQ",
"IVF_HNSW_SQ",
"IVF_HNSW_PQ",
"FTS",
"BTREE",
"BITMAP",
"LABEL_LIST",
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
]
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
num_indices: Optional[int] = None

View File

@@ -177,60 +177,6 @@ async def test_analyze_plan(table: AsyncTable):
assert "metrics=" in res
@pytest.fixture
def table_with_id(tmpdir_factory) -> Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
db = lancedb.connect(tmp_path)
data = pa.table(
{
"id": pa.array([1, 2, 3, 4], type=pa.int64()),
"text": pa.array(["a", "b", "cat", "dog"]),
"vector": pa.array(
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test_with_id", data)
table.create_fts_index("text", with_position=False, use_tantivy=False)
return table
def test_hybrid_prefilter_explain_plan(table_with_id: Table):
"""
Verify that the prefilter logic is not inverted in LanceHybridQueryBuilder.
"""
plan_prefilter = (
table_with_id.search(query_type="hybrid")
.vector([0.0, 0.0])
.text("dog")
.where("id = 1", prefilter=True)
.limit(2)
.explain_plan(verbose=True)
)
plan_postfilter = (
table_with_id.search(query_type="hybrid")
.vector([0.0, 0.0])
.text("dog")
.where("id = 1", prefilter=False)
.limit(2)
.explain_plan(verbose=True)
)
# prefilter=True: filter is pushed into the LanceRead scan.
# The FTS sub-plan exposes this as "full_filter=id = Int64(1)" inside LanceRead.
assert "full_filter=id = Int64(1)" in plan_prefilter, (
f"Should push the filter into the scan.\nPlan:\n{plan_prefilter}"
)
# prefilter=False: filter is applied as a separate FilterExec after the search.
# The filter must NOT be embedded in the scan.
assert "full_filter=id = Int64(1)" not in plan_postfilter, (
f"Should NOT push the filter into the scan.\nPlan:\n{plan_postfilter}"
)
def test_normalize_scores():
cases = [
(pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),

View File

@@ -3,7 +3,6 @@
from datetime import timedelta
import random
from typing import get_args, get_type_hints
import pyarrow as pa
import pytest
@@ -23,7 +22,6 @@ from lancedb.index import (
HnswSq,
FTS,
)
from lancedb.table import IndexStatistics
@pytest_asyncio.fixture
@@ -285,23 +283,3 @@ async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
for v in range(256):
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
assert res["id"][0].as_py() == v
def test_index_statistics_index_type_lists_all_supported_values():
expected_index_types = {
"IVF_FLAT",
"IVF_SQ",
"IVF_PQ",
"IVF_RQ",
"IVF_HNSW_SQ",
"IVF_HNSW_PQ",
"FTS",
"BTREE",
"BITMAP",
"LABEL_LIST",
}
assert (
set(get_args(get_type_hints(IndexStatistics)["index_type"]))
== expected_index_types
)

View File

@@ -30,7 +30,6 @@ from lancedb.query import (
PhraseQuery,
Query,
FullTextSearchQuery,
ensure_vector_query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
@@ -1502,18 +1501,6 @@ def test_search_empty_table(mem_db):
assert results == []
def test_ensure_vector_query_empty_list():
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
with pytest.raises(ValueError, match="non-empty"):
ensure_vector_query([])
def test_ensure_vector_query_nested_empty_list():
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
with pytest.raises(ValueError, match="non-empty"):
ensure_vector_query([[]])
def test_fast_search(tmp_path):
db = lancedb.connect(tmp_path)

View File

@@ -527,102 +527,6 @@ async def test_add_async(mem_db_async: AsyncConnection):
assert await table.count_rows() == 3
def test_add_progress_callback(mem_db: DBConnection):
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
updates = []
table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
assert len(table) == 4
# The done callback always fires, so we should always get at least one.
assert len(updates) >= 1, "expected at least one progress callback"
for p in updates:
assert "output_rows" in p
assert "output_bytes" in p
assert "total_rows" in p
assert "elapsed_seconds" in p
assert "active_tasks" in p
assert "total_tasks" in p
assert "done" in p
# The last callback should have done=True.
assert updates[-1]["done"] is True
def test_add_progress_tqdm_like(mem_db: DBConnection):
"""Test that a tqdm-like object gets total set and update() called."""
class FakeBar:
def __init__(self):
self.total = None
self.n = 0
self.postfix = None
def update(self, n):
self.n += n
def set_postfix_str(self, s):
self.postfix = s
def refresh(self):
pass
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
bar = FakeBar()
table.add([{"id": 3}, {"id": 4}], progress=bar)
assert len(table) == 4
# Postfix should contain throughput and worker count
if bar.postfix is not None:
assert "MB/s" in bar.postfix
assert "workers" in bar.postfix
def test_add_progress_bool(mem_db: DBConnection):
"""Test that progress=True creates and closes a tqdm bar automatically."""
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
table.add([{"id": 3}, {"id": 4}], progress=True)
assert len(table) == 4
# progress=False should be the same as None
table.add([{"id": 5}], progress=False)
assert len(table) == 5
@pytest.mark.asyncio
async def test_add_progress_callback_async(mem_db_async: AsyncConnection):
"""Progress callbacks work through the async path too."""
table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}])
updates = []
await table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
assert await table.count_rows() == 4
assert len(updates) >= 1
assert updates[-1]["done"] is True
def test_add_progress_callback_error(mem_db: DBConnection):
"""A failing callback must not prevent the write from succeeding."""
table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}])
def bad_callback(p):
raise RuntimeError("boom")
table.add([{"id": 3}, {"id": 4}], progress=bad_callback)
assert len(table) == 4
def test_polars(mem_db: DBConnection):
data = {
"vector": [[3.1, 4.1], [5.9, 26.5]],

View File

@@ -19,7 +19,7 @@ use lancedb::table::{
Table as LanceDbTable,
};
use pyo3::{
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
@@ -299,12 +299,10 @@ impl Table {
})
}
#[pyo3(signature = (data, mode, progress=None))]
pub fn add<'a>(
self_: PyRef<'a, Self>,
data: PyScannable,
mode: String,
progress: Option<Py<PyAny>>,
) -> PyResult<Bound<'a, PyAny>> {
let mut op = self_.inner_ref()?.add(data);
if mode == "append" {
@@ -314,81 +312,6 @@ impl Table {
} else {
return Err(PyValueError::new_err(format!("Invalid mode: {}", mode)));
}
if let Some(progress_obj) = progress {
let is_callable = Python::attach(|py| progress_obj.bind(py).is_callable());
if is_callable {
// Callback: call with a dict of progress info.
op = op.progress(move |p| {
Python::attach(|py| {
let dict = PyDict::new(py);
if let Err(e) = dict
.set_item("output_rows", p.output_rows())
.and_then(|_| dict.set_item("output_bytes", p.output_bytes()))
.and_then(|_| dict.set_item("total_rows", p.total_rows()))
.and_then(|_| {
dict.set_item("elapsed_seconds", p.elapsed().as_secs_f64())
})
.and_then(|_| dict.set_item("active_tasks", p.active_tasks()))
.and_then(|_| dict.set_item("total_tasks", p.total_tasks()))
.and_then(|_| dict.set_item("done", p.done()))
{
log::warn!("progress dict error: {e}");
return;
}
if let Err(e) = progress_obj.call1(py, (dict,)) {
log::warn!("progress callback error: {e}");
}
});
});
} else {
// tqdm-like: has update() method.
let mut last_rows: usize = 0;
let mut total_set = false;
op = op.progress(move |p| {
let current = p.output_rows();
let prev = last_rows;
last_rows = current;
Python::attach(|py| {
if let Some(total) = p.total_rows()
&& !total_set
{
if let Err(e) = progress_obj.setattr(py, "total", total) {
log::warn!("progress setattr error: {e}");
}
total_set = true;
}
let delta = current.saturating_sub(prev);
if delta > 0 {
if let Err(e) = progress_obj.call_method1(py, "update", (delta,)) {
log::warn!("progress update error: {e}");
}
// Show throughput and active workers in tqdm postfix.
let elapsed = p.elapsed().as_secs_f64();
if elapsed > 0.0 {
let mb_per_sec = p.output_bytes() as f64 / elapsed / 1_000_000.0;
let postfix = format!(
"{:.1} MB/s | {}/{} workers",
mb_per_sec,
p.active_tasks(),
p.total_tasks()
);
if let Err(e) =
progress_obj.call_method1(py, "set_postfix_str", (postfix,))
{
log::warn!("progress set_postfix_str error: {e}");
}
}
}
if p.done() {
// Force a final refresh so the bar shows completion.
if let Err(e) = progress_obj.call_method0(py, "refresh") {
log::warn!("progress refresh error: {e}");
}
}
});
});
}
}
future_into_py(self_.py(), async move {
let result = op.execute().await.infer_error()?;

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.27.2-beta.1"
version = "0.27.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -596,8 +596,11 @@ pub struct ConnectBuilder {
}
#[cfg(feature = "remote")]
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 3] = [
("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name"),
("AZURE_CLIENT_ID", "azure_client_id"),
("AZURE_TENANT_ID", "azure_tenant_id"),
];
impl ConnectBuilder {
/// Create a new [`ConnectOptions`] with the given database URI.

View File

@@ -443,13 +443,23 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
})?,
);
}
if let Some(v) = options.0.get("azure_storage_account_name") {
headers.insert(
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,
);
// Map azure storage options to x-azure-* headers.
// The option key uses underscores (e.g. "azure_client_id") while the
// header uses hyphens (e.g. "x-azure-client-id").
let azure_opts: [(&str, &str); 3] = [
("azure_storage_account_name", "x-azure-storage-account-name"),
("azure_client_id", "x-azure-client-id"),
("azure_tenant_id", "x-azure-tenant-id"),
];
for (opt_key, header_name) in azure_opts {
if let Some(v) = options.0.get(opt_key) {
headers.insert(
HeaderName::from_static(header_name),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii value '{}' for option '{}'", v, opt_key),
})?,
);
}
}
for (key, value) in &config.extra_headers {
@@ -1072,4 +1082,34 @@ mod tests {
_ => panic!("Expected Runtime error"),
}
}
#[test]
fn test_default_headers_azure_opts() {
let mut opts = HashMap::new();
opts.insert(
"azure_storage_account_name".to_string(),
"myaccount".to_string(),
);
opts.insert("azure_client_id".to_string(), "my-client-id".to_string());
opts.insert("azure_tenant_id".to_string(), "my-tenant-id".to_string());
let remote_opts = RemoteOptions::new(opts);
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"test-key",
"us-east-1",
"testdb",
false,
&remote_opts,
None,
&ClientConfig::default(),
)
.unwrap();
assert_eq!(
headers.get("x-azure-storage-account-name").unwrap(),
"myaccount"
);
assert_eq!(headers.get("x-azure-client-id").unwrap(), "my-client-id");
assert_eq!(headers.get("x-azure-tenant-id").unwrap(), "my-tenant-id");
}
}

View File

@@ -72,10 +72,6 @@ impl ServerVersion {
pub fn support_structural_fts(&self) -> bool {
self.0 >= semver::Version::new(0, 3, 0)
}
pub fn support_multipart_write(&self) -> bool {
self.0 >= semver::Version::new(0, 4, 0)
}
}
pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
@@ -782,7 +778,12 @@ impl RemoteOptions {
impl From<StorageOptions> for RemoteOptions {
fn from(options: StorageOptions) -> Self {
let supported_opts = vec!["account_name", "azure_storage_account_name"];
let supported_opts = vec![
"account_name",
"azure_storage_account_name",
"azure_client_id",
"azure_tenant_id",
];
let mut filtered = HashMap::new();
for opt in supported_opts {
if let Some(v) = options.0.get(opt) {

File diff suppressed because it is too large Load Diff

View File

@@ -11,14 +11,10 @@ use arrow_ipc::CompressionType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::StreamExt;
use http::header::CONTENT_TYPE;
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
use crate::Error;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
@@ -26,16 +22,13 @@ use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
use crate::remote::table::RemoteTable;
use crate::table::AddResult;
use crate::table::datafusion::insert::COUNT_SCHEMA;
use crate::table::write_progress::WriteProgressTracker;
/// ExecutionPlan for inserting data into a remote LanceDB table.
///
/// Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint.
///
/// When `upload_id` is set, inserts are staged as part of a multipart write
/// session and the plan supports multiple partitions for parallel uploads.
/// Without `upload_id`, the plan requires a single partition and commits
/// immediately.
/// This plan:
/// 1. Requires single partition (no parallel remote inserts yet)
/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint
/// 3. Stores AddResult for retrieval after execution
#[derive(Debug)]
pub struct RemoteInsertExec<S: HttpSend = Sender> {
table_name: String,
@@ -45,69 +38,21 @@ pub struct RemoteInsertExec<S: HttpSend = Sender> {
overwrite: bool,
properties: PlanProperties,
add_result: Arc<Mutex<Option<AddResult>>>,
metrics: ExecutionPlanMetricsSet,
upload_id: Option<String>,
tracker: Option<Arc<WriteProgressTracker>>,
}
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
/// Create a new single-partition RemoteInsertExec.
/// Create a new RemoteInsertExec.
pub fn new(
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
tracker: Option<Arc<WriteProgressTracker>>,
) -> Self {
Self::new_inner(
table_name, identifier, client, input, overwrite, None, tracker,
)
}
/// Create a multi-partition RemoteInsertExec for use with multipart writes.
///
/// Each partition's insert is staged under the given `upload_id` without
/// committing. The caller is responsible for calling the complete (or abort)
/// endpoint after all partitions finish.
pub fn new_multipart(
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
upload_id: String,
tracker: Option<Arc<WriteProgressTracker>>,
) -> Self {
Self::new_inner(
table_name,
identifier,
client,
input,
overwrite,
Some(upload_id),
tracker,
)
}
fn new_inner(
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
upload_id: Option<String>,
tracker: Option<Arc<WriteProgressTracker>>,
) -> Self {
let num_partitions = if upload_id.is_some() {
input.output_partitioning().partition_count()
} else {
1
};
let schema = COUNT_SCHEMA.clone();
let properties = PlanProperties::new(
EquivalenceProperties::new(schema),
datafusion_physical_plan::Partitioning::UnknownPartitioning(num_partitions),
datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
datafusion_physical_plan::execution_plan::EmissionType::Final,
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
);
@@ -120,9 +65,6 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
overwrite,
properties,
add_result: Arc::new(Mutex::new(None)),
metrics: ExecutionPlanMetricsSet::new(),
upload_id,
tracker,
}
}
@@ -141,7 +83,6 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
fn stream_as_http_body(
data: SendableRecordBatchStream,
error_tx: tokio::sync::oneshot::Sender<DataFusionError>,
tracker: Option<Arc<WriteProgressTracker>>,
) -> DataFusionResult<reqwest::Body> {
let options = arrow_ipc::writer::IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
@@ -153,46 +94,37 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
let stream = futures::stream::try_unfold(
(data, writer, Some(error_tx), false),
move |(mut data, mut writer, error_tx, finished)| {
let tracker = tracker.clone();
async move {
if finished {
return Ok(None);
move |(mut data, mut writer, error_tx, finished)| async move {
if finished {
return Ok(None);
}
match data.next().await {
Some(Ok(batch)) => {
writer
.write(&batch)
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer, error_tx, false))))
}
match data.next().await {
Some(Ok(batch)) => {
writer
.write(&batch)
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
if let Some(ref t) = tracker {
t.record_bytes(buffer.len());
}
Ok(Some((buffer, (data, writer, error_tx, false))))
Some(Err(e)) => {
// Send the original error through the channel before
// returning a generic error to reqwest.
if let Some(tx) = error_tx {
let _ = tx.send(e);
}
Some(Err(e)) => {
// Send the original error through the channel before
// returning a generic error to reqwest.
if let Some(tx) = error_tx {
let _ = tx.send(e);
}
Err(std::io::Error::other(
"input stream error (see error channel)",
))
}
None => {
writer
.finish()
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
if buffer.is_empty() {
Ok(None)
} else {
if let Some(ref t) = tracker {
t.record_bytes(buffer.len());
}
Ok(Some((buffer, (data, writer, None, true))))
}
Err(std::io::Error::other(
"input stream error (see error channel)",
))
}
None => {
writer
.finish()
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
if buffer.is_empty() {
Ok(None)
} else {
Ok(Some((buffer, (data, writer, None, true))))
}
}
}
@@ -242,11 +174,8 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
}
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
if self.upload_id.is_some() {
vec![datafusion_physical_plan::Distribution::UnspecifiedDistribution]
} else {
vec![datafusion_physical_plan::Distribution::SinglePartition]
}
// Until we have a separate commit endpoint, we need to do all inserts in a single partition
vec![datafusion_physical_plan::Distribution::SinglePartition]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
@@ -262,14 +191,12 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
"RemoteInsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new_inner(
Ok(Arc::new(Self::new(
self.table_name.clone(),
self.identifier.clone(),
self.client.clone(),
children[0].clone(),
self.overwrite,
self.upload_id.clone(),
self.tracker.clone(),
)))
}
@@ -278,29 +205,18 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
if self.upload_id.is_none() && partition != 0 {
if partition != 0 {
return Err(DataFusionError::Internal(
"RemoteInsertExec only supports single partition execution without upload_id"
.to_string(),
"RemoteInsertExec only supports single partition execution".to_string(),
));
}
let input_stream = self.input.execute(partition, context)?;
let input_schema = input_stream.schema();
let input_stream: SendableRecordBatchStream =
Box::pin(InstrumentedRecordBatchStreamAdapter::new(
input_schema,
input_stream,
partition,
&self.metrics,
));
let input_stream = self.input.execute(0, context)?;
let client = self.client.clone();
let identifier = self.identifier.clone();
let overwrite = self.overwrite;
let add_result = self.add_result.clone();
let table_name = self.table_name.clone();
let upload_id = self.upload_id.clone();
let tracker = self.tracker.clone();
let stream = futures::stream::once(async move {
let mut request = client
@@ -310,12 +226,9 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
if overwrite {
request = request.query(&[("mode", "overwrite")]);
}
if let Some(ref uid) = upload_id {
request = request.query(&[("upload_id", uid.as_str())]);
}
let (error_tx, mut error_rx) = tokio::sync::oneshot::channel();
let body = Self::stream_as_http_body(input_stream, error_tx, tracker)?;
let body = Self::stream_as_http_body(input_stream, error_tx)?;
let request = request.body(body);
let result: DataFusionResult<(String, _)> = async {
@@ -349,43 +262,32 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
let (request_id, response) = result?;
// For multipart writes, the staging response is not the final
// version. Only parse AddResult for non-multipart inserts.
if upload_id.is_none() {
let body_text = response.text().await.map_err(|e| {
let body_text = response.text().await.map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),
request_id: request_id.clone(),
status_code: None,
}))
})?;
let parsed_result = if body_text.trim().is_empty() {
// Backward compatible with old servers
AddResult { version: 0 }
} else {
serde_json::from_str(&body_text).map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),
source: format!("Failed to parse add response: {}", e).into(),
request_id: request_id.clone(),
status_code: None,
}))
})?;
let parsed_result = if body_text.trim().is_empty() {
// Backward compatible with old servers
AddResult { version: 0 }
} else {
serde_json::from_str(&body_text).map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: format!("Failed to parse add response: {}", e).into(),
request_id: request_id.clone(),
status_code: None,
}))
})?
};
})?
};
{
let mut res_lock = add_result.lock().map_err(|_| {
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
})?;
*res_lock = Some(parsed_result);
} else {
// We don't use the body in this case, but we should still consume it.
let _ = response.bytes().await.map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),
request_id: request_id.clone(),
status_code: None,
}))
})?;
}
// Return a single batch with count 0 (actual count is tracked in add_result)
@@ -399,10 +301,6 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
stream,
)))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
#[cfg(test)]

View File

@@ -74,10 +74,7 @@ pub mod optimize;
pub mod query;
pub mod schema_evolution;
pub mod update;
pub mod write_progress;
use crate::index::waiter::wait_for_index;
#[cfg(feature = "remote")]
pub(crate) use add_data::PreprocessingOutput;
pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
pub use chrono::Duration;
pub use delete::DeleteResult;
@@ -443,34 +440,6 @@ mod test_utils {
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
pub fn new_with_handler_version_and_config<T>(
name: impl Into<String>,
version: semver::Version,
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
config: crate::remote::ClientConfig,
) -> Self
where
T: Into<reqwest::Body>,
{
let inner = Arc::new(
crate::remote::table::RemoteTable::new_mock_with_version_and_config(
name.into(),
handler.clone(),
Some(version),
config.clone(),
),
);
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config(
handler, config,
));
Self {
inner,
database: Some(database),
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
}
}
@@ -2229,26 +2198,21 @@ impl BaseTable for NativeTable {
let table_schema = Schema::from(&ds.schema().clone());
let num_partitions = if let Some(parallelism) = add.write_parallelism {
parallelism
// Peek at the first batch to estimate a good partition count for
// write parallelism.
let mut peeked = PeekedScannable::new(add.data);
let num_partitions = if let Some(first_batch) = peeked.peek().await {
let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus();
estimate_write_partitions(
first_batch.get_array_memory_size(),
first_batch.num_rows(),
peeked.num_rows(),
max_partitions,
)
} else {
// Peek at the first batch to estimate a good partition count for
// write parallelism.
let mut peeked = PeekedScannable::new(add.data);
let n = if let Some(first_batch) = peeked.peek().await {
let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus();
estimate_write_partitions(
first_batch.get_array_memory_size(),
first_batch.num_rows(),
peeked.num_rows(),
max_partitions,
)
} else {
1
};
add.data = Box::new(peeked);
n
1
};
add.data = Box::new(peeked);
let output = add.into_plan(&table_schema, &table_def)?;
@@ -2277,21 +2241,13 @@ impl BaseTable for NativeTable {
let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
let tracker_for_tasks = output.tracker.clone();
if let Some(ref t) = tracker_for_tasks {
t.set_total_tasks(num_partitions);
}
let _finish = write_progress::FinishOnDrop(output.tracker);
// Execute all partitions in parallel.
let task_ctx = Arc::new(TaskContext::default());
let handles = FuturesUnordered::new();
for partition in 0..num_partitions {
let exec = insert_exec.clone();
let ctx = task_ctx.clone();
let tracker = tracker_for_tasks.clone();
handles.push(tokio::spawn(async move {
let _guard = tracker.as_ref().map(|t| t.track_task());
let mut stream = exec
.execute(partition, ctx)
.map_err(|e| -> Error { e.into() })?;

View File

@@ -13,9 +13,6 @@ use crate::embeddings::EmbeddingRegistry;
use crate::table::datafusion::cast::cast_to_table_schema;
use crate::table::datafusion::reject_nan::reject_nan_vectors;
use crate::table::datafusion::scannable_exec::ScannableExec;
use crate::table::write_progress::ProgressCallback;
use crate::table::write_progress::WriteProgress;
use crate::table::write_progress::WriteProgressTracker;
use crate::{Error, Result};
use super::{BaseTable, TableDefinition, WriteOptions};
@@ -55,8 +52,6 @@ pub struct AddDataBuilder {
pub(crate) write_options: WriteOptions,
pub(crate) on_nan_vectors: NaNVectorBehavior,
pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
pub(crate) progress_callback: Option<ProgressCallback>,
pub(crate) write_parallelism: Option<usize>,
}
impl std::fmt::Debug for AddDataBuilder {
@@ -82,8 +77,6 @@ impl AddDataBuilder {
write_options: WriteOptions::default(),
on_nan_vectors: NaNVectorBehavior::default(),
embedding_registry,
progress_callback: None,
write_parallelism: None,
}
}
@@ -108,43 +101,7 @@ impl AddDataBuilder {
self
}
/// Set a callback to receive progress updates during the add operation.
///
/// The callback is invoked once per batch written, and once more with
/// [`WriteProgress::done`] set to `true` when the write completes.
///
/// ```
/// # use lancedb::Table;
/// # async fn example(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
/// let batch = arrow_array::record_batch!(("id", Int32, [1, 2, 3])).unwrap();
/// table.add(batch)
/// .progress(|p| println!("{}/{:?} rows", p.output_rows(), p.total_rows()))
/// .execute()
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn progress(mut self, callback: impl FnMut(&WriteProgress) + Send + 'static) -> Self {
self.progress_callback = Some(Arc::new(std::sync::Mutex::new(callback)));
self
}
/// Set the number of parallel write streams.
///
/// By default, the number of streams is estimated from the data size.
/// Setting this to `1` disables parallel writes.
pub fn write_parallelism(mut self, parallelism: usize) -> Self {
self.write_parallelism = Some(parallelism);
self
}
pub async fn execute(self) -> Result<AddResult> {
if self.write_parallelism.map(|p| p == 0).unwrap_or(false) {
return Err(Error::InvalidInput {
message: "write_parallelism must be greater than 0".to_string(),
});
}
self.parent.clone().add(self).await
}
@@ -173,11 +130,8 @@ impl AddDataBuilder {
scannable_with_embeddings(self.data, table_def, self.embedding_registry.as_ref())?;
let rescannable = self.data.rescannable();
let tracker = self
.progress_callback
.map(|cb| Arc::new(WriteProgressTracker::new(cb, self.data.num_rows())));
let plan: Arc<dyn datafusion_physical_plan::ExecutionPlan> =
Arc::new(ScannableExec::new(self.data, tracker.clone()));
Arc::new(ScannableExec::new(self.data));
// Skip casting when overwriting — the input schema replaces the table schema.
let plan = if overwrite {
plan
@@ -195,7 +149,6 @@ impl AddDataBuilder {
rescannable,
write_options: self.write_options,
mode: self.mode,
tracker,
})
}
}
@@ -208,7 +161,6 @@ pub struct PreprocessingOutput {
pub rescannable: bool,
pub write_options: WriteOptions,
pub mode: AddDataMode,
pub tracker: Option<Arc<WriteProgressTracker>>,
}
/// Check that the input schema is valid for insert.

View File

@@ -12,16 +12,13 @@ use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::transaction::{Operation, Transaction};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
use lance_table::format::Fragment;
use crate::table::dataset::DatasetConsistencyWrapper;
@@ -83,7 +80,6 @@ pub struct InsertExec {
write_params: WriteParams,
properties: PlanProperties,
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
metrics: ExecutionPlanMetricsSet,
}
impl InsertExec {
@@ -109,7 +105,6 @@ impl InsertExec {
write_params,
properties,
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
metrics: ExecutionPlanMetricsSet::new(),
}
}
}
@@ -181,19 +176,6 @@ impl ExecutionPlan for InsertExec {
let total_partitions = self.input.output_partitioning().partition_count();
let ds_wrapper = self.ds_wrapper.clone();
let output_bytes = MetricBuilder::new(&self.metrics).output_bytes(partition);
let input_schema = input_stream.schema();
let input_stream: SendableRecordBatchStream =
Box::pin(InstrumentedRecordBatchStreamAdapter::new(
input_schema,
input_stream.map_ok(move |batch| {
output_bytes.add(batch.get_array_memory_size());
batch
}),
partition,
&self.metrics,
));
let stream = futures::stream::once(async move {
let transaction = InsertBuilder::new(dataset.clone())
.with_params(&write_params)
@@ -233,10 +215,6 @@ impl ExecutionPlan for InsertExec {
stream,
)))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
#[cfg(test)]

View File

@@ -7,21 +7,17 @@ use std::sync::{Arc, Mutex};
use datafusion_common::{DataFusionError, Result as DFResult, Statistics, stats::Precision};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execution_plan::EmissionType,
};
use futures::TryStreamExt;
use crate::table::write_progress::WriteProgressTracker;
use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};
pub(crate) struct ScannableExec {
// We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
pub struct ScannableExec {
// We don't require Scannable to by Sync, so we wrap it in a Mutex to allow safe concurrent access.
source: Mutex<Box<dyn Scannable>>,
num_rows: Option<usize>,
properties: PlanProperties,
tracker: Option<Arc<WriteProgressTracker>>,
}
impl std::fmt::Debug for ScannableExec {
@@ -34,7 +30,7 @@ impl std::fmt::Debug for ScannableExec {
}
impl ScannableExec {
pub fn new(source: Box<dyn Scannable>, tracker: Option<Arc<WriteProgressTracker>>) -> Self {
pub fn new(source: Box<dyn Scannable>) -> Self {
let schema = source.schema();
let eq_properties = EquivalenceProperties::new(schema);
let properties = PlanProperties::new(
@@ -50,7 +46,6 @@ impl ScannableExec {
source,
num_rows,
properties,
tracker,
}
}
}
@@ -107,18 +102,7 @@ impl ExecutionPlan for ScannableExec {
Err(poison) => poison.into_inner().scan_as_stream(),
};
let tracker = self.tracker.clone();
let stream = stream.into_df_stream().map_ok(move |batch| {
if let Some(ref t) = tracker {
t.record_batch(batch.num_rows(), batch.get_array_memory_size());
}
batch
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
stream,
)))
Ok(stream.into_df_stream())
}
fn partition_statistics(&self, _partition: Option<usize>) -> DFResult<Statistics> {

View File

@@ -1,379 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Progress monitoring for write operations.
//!
//! You can add a callback to process progress in [`crate::table::AddDataBuilder::progress`].
//! [`WriteProgress`] is the struct passed to the callback.
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Progress snapshot for a write operation.
#[derive(Debug, Clone)]
pub struct WriteProgress {
// These are private and only accessible via getters, to make it easy to add
// new fields without breaking existing callbacks.
elapsed: Duration,
output_rows: usize,
output_bytes: usize,
total_rows: Option<usize>,
active_tasks: usize,
total_tasks: usize,
done: bool,
}
impl WriteProgress {
/// Wall-clock time since monitoring started.
pub fn elapsed(&self) -> Duration {
self.elapsed
}
/// Number of rows written so far.
pub fn output_rows(&self) -> usize {
self.output_rows
}
/// Number of bytes written so far.
pub fn output_bytes(&self) -> usize {
self.output_bytes
}
/// Total rows expected.
///
/// Populated when the input source reports a row count (e.g. a
/// [`arrow_array::RecordBatch`]). Always `Some` when [`WriteProgress::done`]
/// is `true` — falling back to the actual number of rows written.
pub fn total_rows(&self) -> Option<usize> {
self.total_rows
}
/// Number of parallel write tasks currently in flight.
pub fn active_tasks(&self) -> usize {
self.active_tasks
}
/// Total number of parallel write tasks (i.e. the write parallelism).
pub fn total_tasks(&self) -> usize {
self.total_tasks
}
/// Whether the write operation has completed.
///
/// The final callback always has `done = true`. Callers can use this to
/// finalize progress bars or perform cleanup.
pub fn done(&self) -> bool {
self.done
}
}
/// Callback type for progress updates.
///
/// Callbacks are serialized by the tracker and are never invoked reentrantly,
/// so `FnMut` is safe to use here.
pub type ProgressCallback = Arc<Mutex<dyn FnMut(&WriteProgress) + Send>>;
/// Tracks progress of a write operation and invokes a [`ProgressCallback`].
///
/// Call [`WriteProgressTracker::record_batch`] for each batch written.
/// Call [`WriteProgressTracker::finish`] once after all data is written.
///
/// The callback is never invoked reentrantly: all state updates and callback
/// invocations are serialized behind a single lock.
impl std::fmt::Debug for WriteProgressTracker {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WriteProgressTracker")
.field("total_rows", &self.total_rows)
.finish()
}
}
pub(crate) struct WriteProgressTracker {
rows_and_bytes: std::sync::Mutex<(usize, usize)>,
/// Wire bytes tracked separately by the insert layer. When set (> 0),
/// this takes precedence over the in-memory bytes from `rows_and_bytes`.
wire_bytes: AtomicUsize,
active_tasks: Arc<AtomicUsize>,
total_tasks: AtomicUsize,
start: Instant,
/// Known total rows from the input source, if available.
total_rows: Option<usize>,
callback: ProgressCallback,
}
impl WriteProgressTracker {
pub fn new(callback: ProgressCallback, total_rows: Option<usize>) -> Self {
Self {
rows_and_bytes: std::sync::Mutex::new((0, 0)),
wire_bytes: AtomicUsize::new(0),
active_tasks: Arc::new(AtomicUsize::new(0)),
total_tasks: AtomicUsize::new(1),
start: Instant::now(),
total_rows,
callback,
}
}
/// Set the total number of parallel write tasks (the write parallelism).
pub fn set_total_tasks(&self, n: usize) {
self.total_tasks.store(n, Ordering::Relaxed);
}
/// Increment the active task count. Returns a guard that decrements on drop.
pub fn track_task(&self) -> ActiveTaskGuard {
self.active_tasks.fetch_add(1, Ordering::Relaxed);
ActiveTaskGuard(self.active_tasks.clone())
}
/// Record a batch of rows passing through the scan node.
pub fn record_batch(&self, rows: usize, bytes: usize) {
// Lock order: callback first, then rows_and_bytes. This is the only
// order used anywhere, so deadlocks cannot occur.
let mut cb = self.callback.lock().unwrap();
let mut guard = self.rows_and_bytes.lock().unwrap();
guard.0 += rows;
guard.1 += bytes;
let progress = self.snapshot(guard.0, guard.1, false);
drop(guard);
cb(&progress);
}
/// Record wire bytes from the insert layer (e.g. IPC-encoded bytes for
/// remote writes). When wire bytes are recorded, they take precedence over
/// the in-memory Arrow bytes tracked by [`record_batch`].
pub fn record_bytes(&self, bytes: usize) {
self.wire_bytes.fetch_add(bytes, Ordering::Relaxed);
}
/// Emit the final progress callback indicating the write is complete.
///
/// `total_rows` is always `Some` on the final callback: it uses the known
/// total if available, or falls back to the number of rows actually written.
pub fn finish(&self) {
let mut cb = self.callback.lock().unwrap();
let guard = self.rows_and_bytes.lock().unwrap();
let mut snap = self.snapshot(guard.0, guard.1, true);
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
drop(guard);
cb(&snap);
}
fn snapshot(&self, rows: usize, in_memory_bytes: usize, done: bool) -> WriteProgress {
let wire = self.wire_bytes.load(Ordering::Relaxed);
// Prefer wire bytes (actual I/O size) when the insert layer is
// tracking them; fall back to in-memory Arrow size otherwise.
// TODO: for local writes, track actual bytes written by Lance
// instead of using in-memory Arrow size as a proxy.
let output_bytes = if wire > 0 { wire } else { in_memory_bytes };
WriteProgress {
elapsed: self.start.elapsed(),
output_rows: rows,
output_bytes,
total_rows: self.total_rows,
active_tasks: self.active_tasks.load(Ordering::Relaxed),
total_tasks: self.total_tasks.load(Ordering::Relaxed),
done,
}
}
}
/// RAII guard that decrements the active task count when dropped.
pub(crate) struct ActiveTaskGuard(Arc<AtomicUsize>);
impl Drop for ActiveTaskGuard {
fn drop(&mut self) {
self.0.fetch_sub(1, Ordering::Relaxed);
}
}
/// RAII guard that calls [`WriteProgressTracker::finish`] on drop.
///
/// This ensures the final `done=true` callback fires even if the write
/// errors or the future is cancelled.
pub(crate) struct FinishOnDrop(pub Option<Arc<WriteProgressTracker>>);
impl Drop for FinishOnDrop {
fn drop(&mut self) {
if let Some(t) = self.0.take() {
t.finish();
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow_array::record_batch;
use crate::connect;
#[tokio::test]
async fn test_progress_monitor_fires_callback() {
let db = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let table = db
.create_table("progress_test", batch)
.execute()
.await
.unwrap();
let callback_count = Arc::new(AtomicUsize::new(0));
let last_rows = Arc::new(AtomicUsize::new(0));
let max_active = Arc::new(AtomicUsize::new(0));
let last_total_tasks = Arc::new(AtomicUsize::new(0));
let cb_count = callback_count.clone();
let cb_rows = last_rows.clone();
let cb_active = max_active.clone();
let cb_total_tasks = last_total_tasks.clone();
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
table
.add(new_data)
.progress(move |p| {
cb_count.fetch_add(1, Ordering::SeqCst);
cb_rows.store(p.output_rows(), Ordering::SeqCst);
cb_active.fetch_max(p.active_tasks(), Ordering::SeqCst);
cb_total_tasks.store(p.total_tasks(), Ordering::SeqCst);
})
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 6);
assert!(callback_count.load(Ordering::SeqCst) >= 1);
// Progress tracks the newly inserted rows, not the total table size.
assert_eq!(last_rows.load(Ordering::SeqCst), 3);
// At least one callback should have seen an active task.
assert!(max_active.load(Ordering::SeqCst) >= 1);
// total_tasks should reflect the write parallelism.
assert!(last_total_tasks.load(Ordering::SeqCst) >= 1);
}
#[tokio::test]
async fn test_progress_done_fires_at_end() {
let db = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let table = db
.create_table("progress_done", batch)
.execute()
.await
.unwrap();
let seen_done = Arc::new(std::sync::Mutex::new(Vec::<bool>::new()));
let seen = seen_done.clone();
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
table
.add(new_data)
.progress(move |p| {
seen.lock().unwrap().push(p.done());
})
.execute()
.await
.unwrap();
let done_flags = seen_done.lock().unwrap();
assert!(!done_flags.is_empty(), "at least one callback must fire");
// Only the last callback should have done=true.
let last = *done_flags.last().unwrap();
assert!(last, "last callback must have done=true");
// All earlier callbacks should have done=false.
for &d in done_flags.iter().rev().skip(1) {
assert!(!d, "non-final callbacks must have done=false");
}
}
#[tokio::test]
async fn test_progress_total_rows_known() {
let db = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let table = db
.create_table("total_known", batch)
.execute()
.await
.unwrap();
let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
let seen = seen_total.clone();
// RecordBatch implements Scannable with num_rows() -> Some(3)
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
table
.add(new_data)
.progress(move |p| {
seen.lock().unwrap().push(p.total_rows());
})
.execute()
.await
.unwrap();
let totals = seen_total.lock().unwrap();
// All callbacks (including done) should have total_rows = Some(3)
assert!(
totals.contains(&Some(3)),
"expected total_rows=Some(3) in at least one callback, got: {:?}",
*totals
);
}
#[tokio::test]
async fn test_progress_total_rows_unknown() {
use arrow_array::RecordBatchIterator;
let db = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let table = db
.create_table("total_unknown", batch)
.execute()
.await
.unwrap();
let seen_total = Arc::new(std::sync::Mutex::new(Vec::new()));
let seen = seen_total.clone();
// RecordBatchReader does not provide num_rows, so total_rows should be
// None in intermediate callbacks but always Some on the done callback.
let schema = arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"id",
arrow_schema::DataType::Int32,
false,
)]);
let new_data: Box<dyn arrow_array::RecordBatchReader + Send> =
Box::new(RecordBatchIterator::new(
vec![Ok(record_batch!(("id", Int32, [4, 5, 6])).unwrap())],
Arc::new(schema),
));
table
.add(new_data)
.progress(move |p| {
seen.lock().unwrap().push((p.total_rows(), p.done()));
})
.execute()
.await
.unwrap();
let entries = seen_total.lock().unwrap();
assert!(!entries.is_empty(), "at least one callback must fire");
for (total, done) in entries.iter() {
if *done {
assert!(
total.is_some(),
"done callback must have total_rows set, got: {:?}",
total
);
} else {
assert_eq!(
*total, None,
"intermediate callback must have total_rows=None, got: {:?}",
total
);
}
}
}
}