feat: progress bar for add() (#3067)

## Summary

Adds progress reporting for `table.add()` so users can track large write
operations. The progress callback is available in Rust, Python (sync and
async), and through the PyO3 bindings.

### Usage

Pass `progress=True` to get an automatic tqdm bar:

```python
table.add(data, progress=True)
# 100%|██████████| 1000000/1000000 [00:12<00:00, 82345 rows/s, 45.2 MB/s | 4/4 workers]
```

Or pass a tqdm bar for more control:

```python
from tqdm import tqdm

with tqdm(unit=" rows") as pbar:
    table.add(data, progress=pbar)
```

Or use a callback for custom progress handling:

```python
def on_progress(p):
    print(f"{p['output_rows']}/{p['total_rows']} rows, "
          f"{p['active_tasks']}/{p['total_tasks']} workers, "
          f"done={p['done']}")

table.add(data, progress=on_progress)
```

In Rust:

```rust
table.add(data)
    .progress(|p| println!("{}/{:?} rows", p.output_rows(), p.total_rows()))
    .execute()
    .await?;
```

### Details

- `WriteProgress` struct in Rust with getters for `elapsed`,
`output_rows`, `output_bytes`, `total_rows`, `active_tasks`,
`total_tasks`, and `done`. Fields are private behind getters so new
fields can be added without breaking changes.
- `WriteProgressTracker` tracks progress across parallel write tasks
using a mutex for row/byte counts and atomics for active task counts.
- Active task tracking uses an RAII guard pattern (`ActiveTaskGuard`)
that increments on creation and decrements on drop.
- For remote writes, `output_bytes` reflects IPC wire bytes rather than
in-memory Arrow size. For local writes it uses in-memory Arrow size as a
proxy (see TODO below).
- tqdm postfix displays throughput (MB/s) and worker utilization
(active/total).
- The `done` callback always fires, even on error (via `FinishOnDrop`),
so progress bars are always finalized.

### TODO

- Track actual bytes written to disk for local tables. This requires
Lance to expose a progress callback from its write path. See
lance-format/lance#6247.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Will Jones
2026-03-23 16:14:13 -07:00
committed by GitHub
parent a0228036ae
commit 1d6e00b902
14 changed files with 894 additions and 48 deletions

View File

@@ -135,7 +135,10 @@ class Table:
def close(self) -> None: ...
async def schema(self) -> pa.Schema: ...
async def add(
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
self,
data: pa.RecordBatchReader,
mode: Literal["append", "overwrite"],
progress: Optional[Any] = None,
) -> AddResult: ...
async def update(
self, updates: Dict[str, str], where: Optional[str]

View File

@@ -4,7 +4,7 @@
from datetime import timedelta
import logging
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
import warnings
from lancedb._lancedb import (
@@ -35,6 +35,7 @@ import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.table import _normalize_progress
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
@@ -308,6 +309,7 @@ class RemoteTable(Table):
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table). It has the same API signature as
the OSS version.
@@ -330,17 +332,29 @@ class RemoteTable(Table):
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
Returns
-------
AddResult
An object containing the new version number of the table after adding data.
"""
return LOOP.run(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
progress, owns = _normalize_progress(progress)
try:
return LOOP.run(
self._table.add(
data,
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
progress=progress,
)
)
)
finally:
if owns:
progress.close()
def search(
self,

View File

@@ -14,6 +14,7 @@ from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
@@ -556,6 +557,21 @@ def _table_uri(base: str, table_name: str) -> str:
return join_uri(base, f"{table_name}.lance")
def _normalize_progress(progress):
"""Normalize a ``progress`` parameter for :meth:`Table.add`.
Returns ``(progress_obj, owns)`` where *owns* is True when we created a
tqdm bar that the caller must close.
"""
if progress is True:
from tqdm.auto import tqdm
return tqdm(unit=" rows"), True
if progress is False or progress is None:
return None, False
return progress, False
class Table(ABC):
"""
A Table is a collection of Records in a LanceDB Database.
@@ -974,6 +990,7 @@ class Table(ABC):
mode: AddMode = "append",
on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table).
@@ -995,6 +1012,29 @@ class Table(ABC):
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
Progress reporting during the add operation. Can be:
- ``True`` to automatically create and display a tqdm progress
bar (requires ``tqdm`` to be installed)::
table.add(data, progress=True)
- A **callable** that receives a dict with keys ``output_rows``,
``output_bytes``, ``total_rows``, ``elapsed_seconds``,
``active_tasks``, ``total_tasks``, and ``done``::
def on_progress(p):
print(f"{p['output_rows']}/{p['total_rows']} rows, "
f"{p['active_tasks']}/{p['total_tasks']} workers")
table.add(data, progress=on_progress)
- A **tqdm-compatible** progress bar whose ``total`` and
``update()`` will be called automatically. The postfix shows
write throughput (MB/s) and active worker count::
with tqdm() as pbar:
table.add(data, progress=pbar)
Returns
-------
@@ -2492,6 +2532,7 @@ class LanceTable(Table):
mode: AddMode = "append",
on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add data to the table.
If vector columns are missing and the table
@@ -2510,17 +2551,29 @@ class LanceTable(Table):
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: bool, callable, or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
Returns
-------
int
The number of vectors in the table.
"""
return LOOP.run(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
progress, owns = _normalize_progress(progress)
try:
return LOOP.run(
self._table.add(
data,
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
progress=progress,
)
)
)
finally:
if owns:
progress.close()
def merge(
self,
@@ -3769,6 +3822,7 @@ class AsyncTable:
mode: Optional[Literal["append", "overwrite"]] = "append",
on_bad_vectors: Optional[OnBadVectorsType] = None,
fill_value: Optional[float] = None,
progress: Optional[Union[bool, Callable, Any]] = None,
) -> AddResult:
"""Add more data to the [Table](Table).
@@ -3790,6 +3844,9 @@ class AsyncTable:
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
progress: callable or tqdm-like, optional
A callback or tqdm-compatible progress bar. See
:meth:`Table.add` for details.
"""
schema = await self.schema()
@@ -3813,8 +3870,9 @@ class AsyncTable:
)
_register_optional_converters()
data = to_scannable(data)
progress, owns = _normalize_progress(progress)
try:
return await self._inner.add(data, mode or "append")
return await self._inner.add(data, mode or "append", progress=progress)
except RuntimeError as e:
if "Cast error" in str(e):
raise ValueError(e)
@@ -3822,6 +3880,9 @@ class AsyncTable:
raise ValueError(e)
else:
raise
finally:
if owns:
progress.close()
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""

View File

@@ -527,6 +527,102 @@ async def test_add_async(mem_db_async: AsyncConnection):
assert await table.count_rows() == 3
def test_add_progress_callback(mem_db: DBConnection):
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
updates = []
table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
assert len(table) == 4
# The done callback always fires, so we should always get at least one.
assert len(updates) >= 1, "expected at least one progress callback"
for p in updates:
assert "output_rows" in p
assert "output_bytes" in p
assert "total_rows" in p
assert "elapsed_seconds" in p
assert "active_tasks" in p
assert "total_tasks" in p
assert "done" in p
# The last callback should have done=True.
assert updates[-1]["done"] is True
def test_add_progress_tqdm_like(mem_db: DBConnection):
"""Test that a tqdm-like object gets total set and update() called."""
class FakeBar:
def __init__(self):
self.total = None
self.n = 0
self.postfix = None
def update(self, n):
self.n += n
def set_postfix_str(self, s):
self.postfix = s
def refresh(self):
pass
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
bar = FakeBar()
table.add([{"id": 3}, {"id": 4}], progress=bar)
assert len(table) == 4
# Postfix should contain throughput and worker count
if bar.postfix is not None:
assert "MB/s" in bar.postfix
assert "workers" in bar.postfix
def test_add_progress_bool(mem_db: DBConnection):
"""Test that progress=True creates and closes a tqdm bar automatically."""
table = mem_db.create_table(
"test",
data=[{"id": 1}, {"id": 2}],
)
table.add([{"id": 3}, {"id": 4}], progress=True)
assert len(table) == 4
# progress=False should be the same as None
table.add([{"id": 5}], progress=False)
assert len(table) == 5
@pytest.mark.asyncio
async def test_add_progress_callback_async(mem_db_async: AsyncConnection):
"""Progress callbacks work through the async path too."""
table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}])
updates = []
await table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
assert await table.count_rows() == 4
assert len(updates) >= 1
assert updates[-1]["done"] is True
def test_add_progress_callback_error(mem_db: DBConnection):
"""A failing callback must not prevent the write from succeeding."""
table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}])
def bad_callback(p):
raise RuntimeError("boom")
table.add([{"id": 3}, {"id": 4}], progress=bad_callback)
assert len(table) == 4
def test_polars(mem_db: DBConnection):
data = {
"vector": [[3.1, 4.1], [5.9, 26.5]],