mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-17 03:50:38 +00:00
feat: progress bar for add() (#3067)
## Summary
Adds progress reporting for `table.add()` so users can track large write
operations. The progress callback is available in Rust, Python (sync and
async), and through the PyO3 bindings.
### Usage
Pass `progress=True` to get an automatic tqdm bar:
```python
table.add(data, progress=True)
# 100%|██████████| 1000000/1000000 [00:12<00:00, 82345 rows/s, 45.2 MB/s | 4/4 workers]
```
Or pass a tqdm bar for more control:
```python
from tqdm import tqdm
with tqdm(unit=" rows") as pbar:
table.add(data, progress=pbar)
```
Or use a callback for custom progress handling:
```python
def on_progress(p):
print(f"{p['output_rows']}/{p['total_rows']} rows, "
f"{p['active_tasks']}/{p['total_tasks']} workers, "
f"done={p['done']}")
table.add(data, progress=on_progress)
```
In Rust:
```rust
table.add(data)
.progress(|p| println!("{}/{:?} rows", p.output_rows(), p.total_rows()))
.execute()
.await?;
```
### Details
- `WriteProgress` struct in Rust with getters for `elapsed`,
`output_rows`, `output_bytes`, `total_rows`, `active_tasks`,
`total_tasks`, and `done`. Fields are private behind getters so new
fields can be added without breaking changes.
- `WriteProgressTracker` tracks progress across parallel write tasks
using a mutex for row/byte counts and atomics for active task counts.
- Active task tracking uses an RAII guard pattern (`ActiveTaskGuard`)
that increments on creation and decrements on drop.
- For remote writes, `output_bytes` reflects IPC wire bytes rather than
in-memory Arrow size. For local writes it uses in-memory Arrow size as a
proxy (see TODO below).
- tqdm postfix displays throughput (MB/s) and worker utilization
(active/total).
- The `done` callback always fires, even on error (via `FinishOnDrop`),
so progress bars are always finalized.
### TODO
- Track actual bytes written to disk for local tables. This requires
Lance to expose a progress callback from its write path. See
lance-format/lance#6247.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -23,6 +23,7 @@ lance-namespace.workspace = true
|
||||
lance-namespace-impls.workspace = true
|
||||
lance-io.workspace = true
|
||||
env_logger.workspace = true
|
||||
log.workspace = true
|
||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.26", features = [
|
||||
"attributes",
|
||||
|
||||
@@ -135,7 +135,10 @@ class Table:
|
||||
def close(self) -> None: ...
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
async def add(
|
||||
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
|
||||
self,
|
||||
data: pa.RecordBatchReader,
|
||||
mode: Literal["append", "overwrite"],
|
||||
progress: Optional[Any] = None,
|
||||
) -> AddResult: ...
|
||||
async def update(
|
||||
self, updates: Dict[str, str], where: Optional[str]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
|
||||
import warnings
|
||||
|
||||
from lancedb._lancedb import (
|
||||
@@ -35,6 +35,7 @@ import pyarrow as pa
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.table import _normalize_progress
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
||||
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
|
||||
@@ -308,6 +309,7 @@ class RemoteTable(Table):
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table). It has the same API signature as
|
||||
the OSS version.
|
||||
@@ -330,17 +332,29 @@ class RemoteTable(Table):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
progress: bool, callable, or tqdm-like, optional
|
||||
A callback or tqdm-compatible progress bar. See
|
||||
:meth:`Table.add` for details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AddResult
|
||||
An object containing the new version number of the table after adding data.
|
||||
"""
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
progress, owns = _normalize_progress(progress)
|
||||
try:
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data,
|
||||
mode=mode,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
progress=progress,
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if owns:
|
||||
progress.close()
|
||||
|
||||
def search(
|
||||
self,
|
||||
|
||||
@@ -14,6 +14,7 @@ from functools import cached_property
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
@@ -556,6 +557,21 @@ def _table_uri(base: str, table_name: str) -> str:
|
||||
return join_uri(base, f"{table_name}.lance")
|
||||
|
||||
|
||||
def _normalize_progress(progress):
|
||||
"""Normalize a ``progress`` parameter for :meth:`Table.add`.
|
||||
|
||||
Returns ``(progress_obj, owns)`` where *owns* is True when we created a
|
||||
tqdm bar that the caller must close.
|
||||
"""
|
||||
if progress is True:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
return tqdm(unit=" rows"), True
|
||||
if progress is False or progress is None:
|
||||
return None, False
|
||||
return progress, False
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
@@ -974,6 +990,7 @@ class Table(ABC):
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
@@ -995,6 +1012,29 @@ class Table(ABC):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
progress: bool, callable, or tqdm-like, optional
|
||||
Progress reporting during the add operation. Can be:
|
||||
|
||||
- ``True`` to automatically create and display a tqdm progress
|
||||
bar (requires ``tqdm`` to be installed)::
|
||||
|
||||
table.add(data, progress=True)
|
||||
|
||||
- A **callable** that receives a dict with keys ``output_rows``,
|
||||
``output_bytes``, ``total_rows``, ``elapsed_seconds``,
|
||||
``active_tasks``, ``total_tasks``, and ``done``::
|
||||
|
||||
def on_progress(p):
|
||||
print(f"{p['output_rows']}/{p['total_rows']} rows, "
|
||||
f"{p['active_tasks']}/{p['total_tasks']} workers")
|
||||
table.add(data, progress=on_progress)
|
||||
|
||||
- A **tqdm-compatible** progress bar whose ``total`` and
|
||||
``update()`` will be called automatically. The postfix shows
|
||||
write throughput (MB/s) and active worker count::
|
||||
|
||||
with tqdm() as pbar:
|
||||
table.add(data, progress=pbar)
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -2492,6 +2532,7 @@ class LanceTable(Table):
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
||||
) -> AddResult:
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
@@ -2510,17 +2551,29 @@ class LanceTable(Table):
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
progress: bool, callable, or tqdm-like, optional
|
||||
A callback or tqdm-compatible progress bar. See
|
||||
:meth:`Table.add` for details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
progress, owns = _normalize_progress(progress)
|
||||
try:
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data,
|
||||
mode=mode,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
progress=progress,
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if owns:
|
||||
progress.close()
|
||||
|
||||
def merge(
|
||||
self,
|
||||
@@ -3769,6 +3822,7 @@ class AsyncTable:
|
||||
mode: Optional[Literal["append", "overwrite"]] = "append",
|
||||
on_bad_vectors: Optional[OnBadVectorsType] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
progress: Optional[Union[bool, Callable, Any]] = None,
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
@@ -3790,6 +3844,9 @@ class AsyncTable:
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
progress: callable or tqdm-like, optional
|
||||
A callback or tqdm-compatible progress bar. See
|
||||
:meth:`Table.add` for details.
|
||||
|
||||
"""
|
||||
schema = await self.schema()
|
||||
@@ -3813,8 +3870,9 @@ class AsyncTable:
|
||||
)
|
||||
_register_optional_converters()
|
||||
data = to_scannable(data)
|
||||
progress, owns = _normalize_progress(progress)
|
||||
try:
|
||||
return await self._inner.add(data, mode or "append")
|
||||
return await self._inner.add(data, mode or "append", progress=progress)
|
||||
except RuntimeError as e:
|
||||
if "Cast error" in str(e):
|
||||
raise ValueError(e)
|
||||
@@ -3822,6 +3880,9 @@ class AsyncTable:
|
||||
raise ValueError(e)
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
if owns:
|
||||
progress.close()
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
|
||||
@@ -527,6 +527,102 @@ async def test_add_async(mem_db_async: AsyncConnection):
|
||||
assert await table.count_rows() == 3
|
||||
|
||||
|
||||
def test_add_progress_callback(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[{"id": 1}, {"id": 2}],
|
||||
)
|
||||
|
||||
updates = []
|
||||
table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
|
||||
|
||||
assert len(table) == 4
|
||||
# The done callback always fires, so we should always get at least one.
|
||||
assert len(updates) >= 1, "expected at least one progress callback"
|
||||
for p in updates:
|
||||
assert "output_rows" in p
|
||||
assert "output_bytes" in p
|
||||
assert "total_rows" in p
|
||||
assert "elapsed_seconds" in p
|
||||
assert "active_tasks" in p
|
||||
assert "total_tasks" in p
|
||||
assert "done" in p
|
||||
# The last callback should have done=True.
|
||||
assert updates[-1]["done"] is True
|
||||
|
||||
|
||||
def test_add_progress_tqdm_like(mem_db: DBConnection):
|
||||
"""Test that a tqdm-like object gets total set and update() called."""
|
||||
|
||||
class FakeBar:
|
||||
def __init__(self):
|
||||
self.total = None
|
||||
self.n = 0
|
||||
self.postfix = None
|
||||
|
||||
def update(self, n):
|
||||
self.n += n
|
||||
|
||||
def set_postfix_str(self, s):
|
||||
self.postfix = s
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[{"id": 1}, {"id": 2}],
|
||||
)
|
||||
|
||||
bar = FakeBar()
|
||||
table.add([{"id": 3}, {"id": 4}], progress=bar)
|
||||
|
||||
assert len(table) == 4
|
||||
# Postfix should contain throughput and worker count
|
||||
if bar.postfix is not None:
|
||||
assert "MB/s" in bar.postfix
|
||||
assert "workers" in bar.postfix
|
||||
|
||||
|
||||
def test_add_progress_bool(mem_db: DBConnection):
|
||||
"""Test that progress=True creates and closes a tqdm bar automatically."""
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[{"id": 1}, {"id": 2}],
|
||||
)
|
||||
|
||||
table.add([{"id": 3}, {"id": 4}], progress=True)
|
||||
assert len(table) == 4
|
||||
|
||||
# progress=False should be the same as None
|
||||
table.add([{"id": 5}], progress=False)
|
||||
assert len(table) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_progress_callback_async(mem_db_async: AsyncConnection):
|
||||
"""Progress callbacks work through the async path too."""
|
||||
table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}])
|
||||
|
||||
updates = []
|
||||
await table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p)))
|
||||
|
||||
assert await table.count_rows() == 4
|
||||
assert len(updates) >= 1
|
||||
assert updates[-1]["done"] is True
|
||||
|
||||
|
||||
def test_add_progress_callback_error(mem_db: DBConnection):
|
||||
"""A failing callback must not prevent the write from succeeding."""
|
||||
table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}])
|
||||
|
||||
def bad_callback(p):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
table.add([{"id": 3}, {"id": 4}], progress=bad_callback)
|
||||
assert len(table) == 4
|
||||
|
||||
|
||||
def test_polars(mem_db: DBConnection):
|
||||
data = {
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
|
||||
@@ -19,7 +19,7 @@ use lancedb::table::{
|
||||
Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
@@ -299,10 +299,12 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (data, mode, progress=None))]
|
||||
pub fn add<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
data: PyScannable,
|
||||
mode: String,
|
||||
progress: Option<Py<PyAny>>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let mut op = self_.inner_ref()?.add(data);
|
||||
if mode == "append" {
|
||||
@@ -312,6 +314,81 @@ impl Table {
|
||||
} else {
|
||||
return Err(PyValueError::new_err(format!("Invalid mode: {}", mode)));
|
||||
}
|
||||
if let Some(progress_obj) = progress {
|
||||
let is_callable = Python::attach(|py| progress_obj.bind(py).is_callable());
|
||||
if is_callable {
|
||||
// Callback: call with a dict of progress info.
|
||||
op = op.progress(move |p| {
|
||||
Python::attach(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
if let Err(e) = dict
|
||||
.set_item("output_rows", p.output_rows())
|
||||
.and_then(|_| dict.set_item("output_bytes", p.output_bytes()))
|
||||
.and_then(|_| dict.set_item("total_rows", p.total_rows()))
|
||||
.and_then(|_| {
|
||||
dict.set_item("elapsed_seconds", p.elapsed().as_secs_f64())
|
||||
})
|
||||
.and_then(|_| dict.set_item("active_tasks", p.active_tasks()))
|
||||
.and_then(|_| dict.set_item("total_tasks", p.total_tasks()))
|
||||
.and_then(|_| dict.set_item("done", p.done()))
|
||||
{
|
||||
log::warn!("progress dict error: {e}");
|
||||
return;
|
||||
}
|
||||
if let Err(e) = progress_obj.call1(py, (dict,)) {
|
||||
log::warn!("progress callback error: {e}");
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
// tqdm-like: has update() method.
|
||||
let mut last_rows: usize = 0;
|
||||
let mut total_set = false;
|
||||
op = op.progress(move |p| {
|
||||
let current = p.output_rows();
|
||||
let prev = last_rows;
|
||||
last_rows = current;
|
||||
Python::attach(|py| {
|
||||
if let Some(total) = p.total_rows()
|
||||
&& !total_set
|
||||
{
|
||||
if let Err(e) = progress_obj.setattr(py, "total", total) {
|
||||
log::warn!("progress setattr error: {e}");
|
||||
}
|
||||
total_set = true;
|
||||
}
|
||||
let delta = current.saturating_sub(prev);
|
||||
if delta > 0 {
|
||||
if let Err(e) = progress_obj.call_method1(py, "update", (delta,)) {
|
||||
log::warn!("progress update error: {e}");
|
||||
}
|
||||
// Show throughput and active workers in tqdm postfix.
|
||||
let elapsed = p.elapsed().as_secs_f64();
|
||||
if elapsed > 0.0 {
|
||||
let mb_per_sec = p.output_bytes() as f64 / elapsed / 1_000_000.0;
|
||||
let postfix = format!(
|
||||
"{:.1} MB/s | {}/{} workers",
|
||||
mb_per_sec,
|
||||
p.active_tasks(),
|
||||
p.total_tasks()
|
||||
);
|
||||
if let Err(e) =
|
||||
progress_obj.call_method1(py, "set_postfix_str", (postfix,))
|
||||
{
|
||||
log::warn!("progress set_postfix_str error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if p.done() {
|
||||
// Force a final refresh so the bar shows completion.
|
||||
if let Err(e) = progress_obj.call_method0(py, "refresh") {
|
||||
log::warn!("progress refresh error: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let result = op.execute().await.infer_error()?;
|
||||
|
||||
Reference in New Issue
Block a user