From 1d6e00b902210c08f615304d3dd970114c85151f Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 23 Mar 2026 16:14:13 -0700 Subject: [PATCH] feat: progress bar for `add()` (#3067) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds progress reporting for `table.add()` so users can track large write operations. The progress callback is available in Rust, Python (sync and async), and through the PyO3 bindings. ### Usage Pass `progress=True` to get an automatic tqdm bar: ```python table.add(data, progress=True) # 100%|██████████| 1000000/1000000 [00:12<00:00, 82345 rows/s, 45.2 MB/s | 4/4 workers] ``` Or pass a tqdm bar for more control: ```python from tqdm import tqdm with tqdm(unit=" rows") as pbar: table.add(data, progress=pbar) ``` Or use a callback for custom progress handling: ```python def on_progress(p): print(f"{p['output_rows']}/{p['total_rows']} rows, " f"{p['active_tasks']}/{p['total_tasks']} workers, " f"done={p['done']}") table.add(data, progress=on_progress) ``` In Rust: ```rust table.add(data) .progress(|p| println!("{}/{:?} rows", p.output_rows(), p.total_rows())) .execute() .await?; ``` ### Details - `WriteProgress` struct in Rust with getters for `elapsed`, `output_rows`, `output_bytes`, `total_rows`, `active_tasks`, `total_tasks`, and `done`. Fields are private behind getters so new fields can be added without breaking changes. - `WriteProgressTracker` tracks progress across parallel write tasks using a mutex for row/byte counts and atomics for active task counts. - Active task tracking uses an RAII guard pattern (`ActiveTaskGuard`) that increments on creation and decrements on drop. - For remote writes, `output_bytes` reflects IPC wire bytes rather than in-memory Arrow size. For local writes it uses in-memory Arrow size as a proxy (see TODO below). - tqdm postfix displays throughput (MB/s) and worker utilization (active/total). 
- The `done` callback always fires, even on error (via `FinishOnDrop`), so progress bars are always finalized. ### TODO - Track actual bytes written to disk for local tables. This requires Lance to expose a progress callback from its write path. See lance-format/lance#6247. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 --- Cargo.lock | 1 + python/Cargo.toml | 1 + python/python/lancedb/_lancedb.pyi | 5 +- python/python/lancedb/remote/table.py | 24 +- python/python/lancedb/table.py | 71 +++- python/python/tests/test_table.py | 96 +++++ python/src/table.rs | 79 +++- rust/lancedb/src/remote/table.rs | 90 +++++ rust/lancedb/src/remote/table/insert.rs | 108 +++-- rust/lancedb/src/table.rs | 9 + rust/lancedb/src/table/add_data.rs | 33 +- rust/lancedb/src/table/datafusion/insert.rs | 22 + .../src/table/datafusion/scannable_exec.rs | 24 +- rust/lancedb/src/table/write_progress.rs | 379 ++++++++++++++++++ 14 files changed, 894 insertions(+), 48 deletions(-) create mode 100644 rust/lancedb/src/table/write_progress.rs diff --git a/Cargo.lock b/Cargo.lock index a5ef703b1..0b9dce9a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4861,6 +4861,7 @@ dependencies = [ "lance-namespace", "lance-namespace-impls", "lancedb", + "log", "pin-project", "pyo3", "pyo3-async-runtimes", diff --git a/python/Cargo.toml b/python/Cargo.toml index c656d8d6d..9578f6d42 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -23,6 +23,7 @@ lance-namespace.workspace = true lance-namespace-impls.workspace = true lance-io.workspace = true env_logger.workspace = true +log.workspace = true pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] } pyo3-async-runtimes = { version = "0.26", features = [ "attributes", diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index c5b35c945..0d4378b10 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -135,7 +135,10 @@ 
class Table: def close(self) -> None: ... async def schema(self) -> pa.Schema: ... async def add( - self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"] + self, + data: pa.RecordBatchReader, + mode: Literal["append", "overwrite"], + progress: Optional[Any] = None, ) -> AddResult: ... async def update( self, updates: Dict[str, str], where: Optional[str] diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 4dd5b428f..905e1481a 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -4,7 +4,7 @@ from datetime import timedelta import logging from functools import cached_property -from typing import Dict, Iterable, List, Optional, Union, Literal +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal import warnings from lancedb._lancedb import ( @@ -35,6 +35,7 @@ import pyarrow as pa from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.merge import LanceMergeInsertBuilder from lancedb.embeddings import EmbeddingFunctionRegistry +from lancedb.table import _normalize_progress from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder from ..table import AsyncTable, IndexStatistics, Query, Table, Tags @@ -308,6 +309,7 @@ class RemoteTable(Table): mode: str = "append", on_bad_vectors: str = "error", fill_value: float = 0.0, + progress: Optional[Union[bool, Callable, Any]] = None, ) -> AddResult: """Add more data to the [Table](Table). It has the same API signature as the OSS version. @@ -330,17 +332,29 @@ class RemoteTable(Table): One of "error", "drop", "fill". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + progress: bool, callable, or tqdm-like, optional + A callback or tqdm-compatible progress bar. See + :meth:`Table.add` for details. Returns ------- AddResult An object containing the new version number of the table after adding data. 
""" - return LOOP.run( - self._table.add( - data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value + progress, owns = _normalize_progress(progress) + try: + return LOOP.run( + self._table.add( + data, + mode=mode, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + progress=progress, + ) ) - ) + finally: + if owns: + progress.close() def search( self, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index e4bf24577..48918222d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -14,6 +14,7 @@ from functools import cached_property from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Iterable, List, @@ -556,6 +557,21 @@ def _table_uri(base: str, table_name: str) -> str: return join_uri(base, f"{table_name}.lance") +def _normalize_progress(progress): + """Normalize a ``progress`` parameter for :meth:`Table.add`. + + Returns ``(progress_obj, owns)`` where *owns* is True when we created a + tqdm bar that the caller must close. + """ + if progress is True: + from tqdm.auto import tqdm + + return tqdm(unit=" rows"), True + if progress is False or progress is None: + return None, False + return progress, False + + class Table(ABC): """ A Table is a collection of Records in a LanceDB Database. @@ -974,6 +990,7 @@ class Table(ABC): mode: AddMode = "append", on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, + progress: Optional[Union[bool, Callable, Any]] = None, ) -> AddResult: """Add more data to the [Table](Table). @@ -995,6 +1012,29 @@ class Table(ABC): One of "error", "drop", "fill". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + progress: bool, callable, or tqdm-like, optional + Progress reporting during the add operation. 
Can be: + + - ``True`` to automatically create and display a tqdm progress + bar (requires ``tqdm`` to be installed):: + + table.add(data, progress=True) + + - A **callable** that receives a dict with keys ``output_rows``, + ``output_bytes``, ``total_rows``, ``elapsed_seconds``, + ``active_tasks``, ``total_tasks``, and ``done``:: + + def on_progress(p): + print(f"{p['output_rows']}/{p['total_rows']} rows, " + f"{p['active_tasks']}/{p['total_tasks']} workers") + table.add(data, progress=on_progress) + + - A **tqdm-compatible** progress bar whose ``total`` and + ``update()`` will be called automatically. The postfix shows + write throughput (MB/s) and active worker count:: + + with tqdm() as pbar: + table.add(data, progress=pbar) Returns ------- @@ -2492,6 +2532,7 @@ class LanceTable(Table): mode: AddMode = "append", on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, + progress: Optional[Union[bool, Callable, Any]] = None, ) -> AddResult: """Add data to the table. If vector columns are missing and the table @@ -2510,17 +2551,29 @@ class LanceTable(Table): One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + progress: bool, callable, or tqdm-like, optional + A callback or tqdm-compatible progress bar. See + :meth:`Table.add` for details. Returns ------- int The number of vectors in the table. 
""" - return LOOP.run( - self._table.add( - data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value + progress, owns = _normalize_progress(progress) + try: + return LOOP.run( + self._table.add( + data, + mode=mode, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + progress=progress, + ) ) - ) + finally: + if owns: + progress.close() def merge( self, @@ -3769,6 +3822,7 @@ class AsyncTable: mode: Optional[Literal["append", "overwrite"]] = "append", on_bad_vectors: Optional[OnBadVectorsType] = None, fill_value: Optional[float] = None, + progress: Optional[Union[bool, Callable, Any]] = None, ) -> AddResult: """Add more data to the [Table](Table). @@ -3790,6 +3844,9 @@ class AsyncTable: One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". + progress: callable or tqdm-like, optional + A callback or tqdm-compatible progress bar. See + :meth:`Table.add` for details. """ schema = await self.schema() @@ -3813,8 +3870,9 @@ class AsyncTable: ) _register_optional_converters() data = to_scannable(data) + progress, owns = _normalize_progress(progress) try: - return await self._inner.add(data, mode or "append") + return await self._inner.add(data, mode or "append", progress=progress) except RuntimeError as e: if "Cast error" in str(e): raise ValueError(e) @@ -3822,6 +3880,9 @@ class AsyncTable: raise ValueError(e) else: raise + finally: + if owns: + progress.close() def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """ diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index eaa1eb7af..fc1788ee0 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -527,6 +527,102 @@ async def test_add_async(mem_db_async: AsyncConnection): assert await table.count_rows() == 3 +def test_add_progress_callback(mem_db: DBConnection): + table = mem_db.create_table( + "test", + data=[{"id": 
1}, {"id": 2}], + ) + + updates = [] + table.add([{"id": 3}, {"id": 4}], progress=lambda p: updates.append(dict(p))) + + assert len(table) == 4 + # The done callback always fires, so we should always get at least one. + assert len(updates) >= 1, "expected at least one progress callback" + for p in updates: + assert "output_rows" in p + assert "output_bytes" in p + assert "total_rows" in p + assert "elapsed_seconds" in p + assert "active_tasks" in p + assert "total_tasks" in p + assert "done" in p + # The last callback should have done=True. + assert updates[-1]["done"] is True + + +def test_add_progress_tqdm_like(mem_db: DBConnection): + """Test that a tqdm-like object gets total set and update() called.""" + + class FakeBar: + def __init__(self): + self.total = None + self.n = 0 + self.postfix = None + + def update(self, n): + self.n += n + + def set_postfix_str(self, s): + self.postfix = s + + def refresh(self): + pass + + table = mem_db.create_table( + "test", + data=[{"id": 1}, {"id": 2}], + ) + + bar = FakeBar() + table.add([{"id": 3}, {"id": 4}], progress=bar) + + assert len(table) == 4 + # Postfix should contain throughput and worker count + if bar.postfix is not None: + assert "MB/s" in bar.postfix + assert "workers" in bar.postfix + + +def test_add_progress_bool(mem_db: DBConnection): + """Test that progress=True creates and closes a tqdm bar automatically.""" + table = mem_db.create_table( + "test", + data=[{"id": 1}, {"id": 2}], + ) + + table.add([{"id": 3}, {"id": 4}], progress=True) + assert len(table) == 4 + + # progress=False should be the same as None + table.add([{"id": 5}], progress=False) + assert len(table) == 5 + + +@pytest.mark.asyncio +async def test_add_progress_callback_async(mem_db_async: AsyncConnection): + """Progress callbacks work through the async path too.""" + table = await mem_db_async.create_table("test", data=[{"id": 1}, {"id": 2}]) + + updates = [] + await table.add([{"id": 3}, {"id": 4}], progress=lambda p: 
updates.append(dict(p))) + + assert await table.count_rows() == 4 + assert len(updates) >= 1 + assert updates[-1]["done"] is True + + +def test_add_progress_callback_error(mem_db: DBConnection): + """A failing callback must not prevent the write from succeeding.""" + table = mem_db.create_table("test", data=[{"id": 1}, {"id": 2}]) + + def bad_callback(p): + raise RuntimeError("boom") + + table.add([{"id": 3}, {"id": 4}], progress=bad_callback) + assert len(table) == 4 + + def test_polars(mem_db: DBConnection): data = { "vector": [[3.1, 4.1], [5.9, 26.5]], diff --git a/python/src/table.rs b/python/src/table.rs index 00015bba8..d44b6c1fd 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -19,7 +19,7 @@ use lancedb::table::{ Table as LanceDbTable, }; use pyo3::{ - Bound, FromPyObject, PyAny, PyRef, PyResult, Python, + Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python, exceptions::{PyKeyError, PyRuntimeError, PyValueError}, pyclass, pymethods, types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods}, @@ -299,10 +299,12 @@ impl Table { }) } + #[pyo3(signature = (data, mode, progress=None))] pub fn add<'a>( self_: PyRef<'a, Self>, data: PyScannable, mode: String, + progress: Option>, ) -> PyResult> { let mut op = self_.inner_ref()?.add(data); if mode == "append" { @@ -312,6 +314,81 @@ impl Table { } else { return Err(PyValueError::new_err(format!("Invalid mode: {}", mode))); } + if let Some(progress_obj) = progress { + let is_callable = Python::attach(|py| progress_obj.bind(py).is_callable()); + if is_callable { + // Callback: call with a dict of progress info. 
+ op = op.progress(move |p| { + Python::attach(|py| { + let dict = PyDict::new(py); + if let Err(e) = dict + .set_item("output_rows", p.output_rows()) + .and_then(|_| dict.set_item("output_bytes", p.output_bytes())) + .and_then(|_| dict.set_item("total_rows", p.total_rows())) + .and_then(|_| { + dict.set_item("elapsed_seconds", p.elapsed().as_secs_f64()) + }) + .and_then(|_| dict.set_item("active_tasks", p.active_tasks())) + .and_then(|_| dict.set_item("total_tasks", p.total_tasks())) + .and_then(|_| dict.set_item("done", p.done())) + { + log::warn!("progress dict error: {e}"); + return; + } + if let Err(e) = progress_obj.call1(py, (dict,)) { + log::warn!("progress callback error: {e}"); + } + }); + }); + } else { + // tqdm-like: has update() method. + let mut last_rows: usize = 0; + let mut total_set = false; + op = op.progress(move |p| { + let current = p.output_rows(); + let prev = last_rows; + last_rows = current; + Python::attach(|py| { + if let Some(total) = p.total_rows() + && !total_set + { + if let Err(e) = progress_obj.setattr(py, "total", total) { + log::warn!("progress setattr error: {e}"); + } + total_set = true; + } + let delta = current.saturating_sub(prev); + if delta > 0 { + if let Err(e) = progress_obj.call_method1(py, "update", (delta,)) { + log::warn!("progress update error: {e}"); + } + // Show throughput and active workers in tqdm postfix. + let elapsed = p.elapsed().as_secs_f64(); + if elapsed > 0.0 { + let mb_per_sec = p.output_bytes() as f64 / elapsed / 1_000_000.0; + let postfix = format!( + "{:.1} MB/s | {}/{} workers", + mb_per_sec, + p.active_tasks(), + p.total_tasks() + ); + if let Err(e) = + progress_obj.call_method1(py, "set_postfix_str", (postfix,)) + { + log::warn!("progress set_postfix_str error: {e}"); + } + } + } + if p.done() { + // Force a final refresh so the bar shows completion. 
+ if let Err(e) = progress_obj.call_method0(py, "refresh") { + log::warn!("progress refresh error: {e}"); + } + } + }); + }); + } + } future_into_py(self_.py(), async move { let result = op.execute().await.infer_error()?; diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 244952a17..c9b807505 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -5,6 +5,7 @@ pub mod insert; use self::insert::RemoteInsertExec; use crate::expr::expr_to_sql_string; +use crate::table::write_progress::FinishOnDrop; use super::ARROW_STREAM_CONTENT_TYPE; use super::client::RequestResultExt; @@ -939,12 +940,15 @@ impl RemoteTable { async fn add_single_partition(&self, output: PreprocessingOutput) -> Result { use crate::remote::retry::RetryCounter; + let _guard = output.tracker.as_ref().map(|t| t.track_task()); + let mut insert: Arc = Arc::new(RemoteInsertExec::new( self.name.clone(), self.identifier.clone(), self.client.clone(), output.plan, output.overwrite, + output.tracker.clone(), )); let mut retry_counter = @@ -1045,6 +1049,11 @@ impl RemoteTable { output: &PreprocessingOutput, num_partitions: usize, ) -> Result<()> { + debug_assert!( + output.rescannable, + "multipart inserts require rescannable input for retry support" + ); + let plan = Arc::new( datafusion_physical_plan::repartition::RepartitionExec::try_new( output.plan.clone(), @@ -1059,14 +1068,18 @@ impl RemoteTable { plan, output.overwrite, upload_id.to_string(), + output.tracker.clone(), )); let task_ctx = Arc::new(datafusion_execution::TaskContext::default()); + let tracker = output.tracker.clone(); let mut join_set = tokio::task::JoinSet::new(); for partition in 0..num_partitions { let exec = insert.clone(); let ctx = task_ctx.clone(); + let tracker = tracker.clone(); join_set.spawn(async move { + let _guard = tracker.as_ref().map(|t| t.track_task()); let mut stream = exec .execute(partition, ctx) .map_err(|e| -> Error { e.into() })?; @@ -1273,6 +1286,11 
@@ impl BaseTable for RemoteTable { let output = add.into_plan(&table_schema, &table_def)?; + if let Some(ref t) = output.tracker { + t.set_total_tasks(num_partitions); + } + let _finish = FinishOnDrop(output.tracker.clone()); + if num_partitions > 1 { self.add_multipart(output, num_partitions).await } else { @@ -1975,6 +1993,7 @@ impl BaseTable for RemoteTable { self.client.clone(), input, overwrite, + None, ))) } } @@ -5170,6 +5189,77 @@ mod tests { ); } + #[tokio::test] + async fn test_multipart_write_progress() { + let callback_count = Arc::new(AtomicUsize::new(0)); + let max_active = Arc::new(AtomicUsize::new(0)); + let last_total_tasks = Arc::new(AtomicUsize::new(0)); + let seen_done = Arc::new(std::sync::Mutex::new(false)); + + let cb_count = callback_count.clone(); + let cb_active = max_active.clone(); + let cb_total = last_total_tasks.clone(); + let cb_done = seen_done.clone(); + + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + return simple_describe_response(); + } + if path == "/v1/table/my_table/multipart_write/create" { + return http::Response::builder() + .status(200) + .body(r#"{"upload_id": "prog-upload"}"#.to_string()) + .unwrap(); + } + if path == "/v1/table/my_table/insert/" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap(); + } + if path == "/v1/table/my_table/multipart_write/complete" { + return http::Response::builder() + .status(200) + .body(r#"{"version": 3}"#.to_string()) + .unwrap(); + } + panic!("Unexpected request path: {}", path); + }, + ); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + table + .add(vec![batch]) + .write_parallelism(2) + .progress(move |p| { + cb_count.fetch_add(1, Ordering::SeqCst); + cb_active.fetch_max(p.active_tasks(), Ordering::SeqCst); + cb_total.store(p.total_tasks(), Ordering::SeqCst); + if 
p.done() { + *cb_done.lock().unwrap() = true; + } + }) + .execute() + .await + .unwrap(); + + assert!( + callback_count.load(Ordering::SeqCst) >= 1, + "expected at least one progress callback" + ); + assert!(*seen_done.lock().unwrap(), "must see done=true"); + assert_eq!(last_total_tasks.load(Ordering::SeqCst), 2); + assert!( + max_active.load(Ordering::SeqCst) >= 1, + "expected at least one active task" + ); + } + #[tokio::test] async fn test_multipart_write_fallback_old_server() { let insert_count = Arc::new(AtomicUsize::new(0)); diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index bc13010c4..d7d30a680 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -11,12 +11,14 @@ use arrow_ipc::CompressionType; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; use futures::StreamExt; use http::header::CONTENT_TYPE; +use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter; use crate::Error; use crate::remote::ARROW_STREAM_CONTENT_TYPE; @@ -24,6 +26,7 @@ use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender}; use crate::remote::table::RemoteTable; use crate::table::AddResult; use crate::table::datafusion::insert::COUNT_SCHEMA; +use crate::table::write_progress::WriteProgressTracker; /// ExecutionPlan for inserting data into a remote LanceDB table. 
/// @@ -42,7 +45,9 @@ pub struct RemoteInsertExec { overwrite: bool, properties: PlanProperties, add_result: Arc>>, + metrics: ExecutionPlanMetricsSet, upload_id: Option, + tracker: Option>, } impl RemoteInsertExec { @@ -53,8 +58,11 @@ impl RemoteInsertExec { client: RestfulLanceDbClient, input: Arc, overwrite: bool, + tracker: Option>, ) -> Self { - Self::new_inner(table_name, identifier, client, input, overwrite, None) + Self::new_inner( + table_name, identifier, client, input, overwrite, None, tracker, + ) } /// Create a multi-partition RemoteInsertExec for use with multipart writes. @@ -69,6 +77,7 @@ impl RemoteInsertExec { input: Arc, overwrite: bool, upload_id: String, + tracker: Option>, ) -> Self { Self::new_inner( table_name, @@ -77,6 +86,7 @@ impl RemoteInsertExec { input, overwrite, Some(upload_id), + tracker, ) } @@ -87,6 +97,7 @@ impl RemoteInsertExec { input: Arc, overwrite: bool, upload_id: Option, + tracker: Option>, ) -> Self { let num_partitions = if upload_id.is_some() { input.output_partitioning().partition_count() @@ -109,7 +120,9 @@ impl RemoteInsertExec { overwrite, properties, add_result: Arc::new(Mutex::new(None)), + metrics: ExecutionPlanMetricsSet::new(), upload_id, + tracker, } } @@ -128,6 +141,7 @@ impl RemoteInsertExec { fn stream_as_http_body( data: SendableRecordBatchStream, error_tx: tokio::sync::oneshot::Sender, + tracker: Option>, ) -> DataFusionResult { let options = arrow_ipc::writer::IpcWriteOptions::default() .try_with_compression(Some(CompressionType::LZ4_FRAME))?; @@ -139,37 +153,46 @@ impl RemoteInsertExec { let stream = futures::stream::try_unfold( (data, writer, Some(error_tx), false), - move |(mut data, mut writer, error_tx, finished)| async move { - if finished { - return Ok(None); - } - match data.next().await { - Some(Ok(batch)) => { - writer - .write(&batch) - .map_err(|e| std::io::Error::other(e.to_string()))?; - let buffer = std::mem::take(writer.get_mut()); - Ok(Some((buffer, (data, writer, error_tx, false)))) + 
move |(mut data, mut writer, error_tx, finished)| { + let tracker = tracker.clone(); + async move { + if finished { + return Ok(None); } - Some(Err(e)) => { - // Send the original error through the channel before - // returning a generic error to reqwest. - if let Some(tx) = error_tx { - let _ = tx.send(e); + match data.next().await { + Some(Ok(batch)) => { + writer + .write(&batch) + .map_err(|e| std::io::Error::other(e.to_string()))?; + let buffer = std::mem::take(writer.get_mut()); + if let Some(ref t) = tracker { + t.record_bytes(buffer.len()); + } + Ok(Some((buffer, (data, writer, error_tx, false)))) } - Err(std::io::Error::other( - "input stream error (see error channel)", - )) - } - None => { - writer - .finish() - .map_err(|e| std::io::Error::other(e.to_string()))?; - let buffer = std::mem::take(writer.get_mut()); - if buffer.is_empty() { - Ok(None) - } else { - Ok(Some((buffer, (data, writer, None, true)))) + Some(Err(e)) => { + // Send the original error through the channel before + // returning a generic error to reqwest. 
+ if let Some(tx) = error_tx { + let _ = tx.send(e); + } + Err(std::io::Error::other( + "input stream error (see error channel)", + )) + } + None => { + writer + .finish() + .map_err(|e| std::io::Error::other(e.to_string()))?; + let buffer = std::mem::take(writer.get_mut()); + if buffer.is_empty() { + Ok(None) + } else { + if let Some(ref t) = tracker { + t.record_bytes(buffer.len()); + } + Ok(Some((buffer, (data, writer, None, true)))) + } } } } @@ -246,6 +269,7 @@ impl ExecutionPlan for RemoteInsertExec { children[0].clone(), self.overwrite, self.upload_id.clone(), + self.tracker.clone(), ))) } @@ -262,12 +286,21 @@ impl ExecutionPlan for RemoteInsertExec { } let input_stream = self.input.execute(partition, context)?; + let input_schema = input_stream.schema(); + let input_stream: SendableRecordBatchStream = + Box::pin(InstrumentedRecordBatchStreamAdapter::new( + input_schema, + input_stream, + partition, + &self.metrics, + )); let client = self.client.clone(); let identifier = self.identifier.clone(); let overwrite = self.overwrite; let add_result = self.add_result.clone(); let table_name = self.table_name.clone(); let upload_id = self.upload_id.clone(); + let tracker = self.tracker.clone(); let stream = futures::stream::once(async move { let mut request = client @@ -282,7 +315,7 @@ impl ExecutionPlan for RemoteInsertExec { } let (error_tx, mut error_rx) = tokio::sync::oneshot::channel(); - let body = Self::stream_as_http_body(input_stream, error_tx)?; + let body = Self::stream_as_http_body(input_stream, error_tx, tracker)?; let request = request.body(body); let result: DataFusionResult<(String, _)> = async { @@ -344,6 +377,15 @@ impl ExecutionPlan for RemoteInsertExec { DataFusionError::Execution("Failed to acquire lock for add_result".to_string()) })?; *res_lock = Some(parsed_result); + } else { + // We don't use the body in this case, but we should still consume it. 
+ let _ = response.bytes().await.map_err(|e| { + DataFusionError::External(Box::new(Error::Http { + source: Box::new(e), + request_id: request_id.clone(), + status_code: None, + })) + })?; } // Return a single batch with count 0 (actual count is tracked in add_result) @@ -357,6 +399,10 @@ impl ExecutionPlan for RemoteInsertExec { stream, ))) } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } } #[cfg(test)] diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index c513f6593..7eac7463a 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -74,6 +74,7 @@ pub mod optimize; pub mod query; pub mod schema_evolution; pub mod update; +pub mod write_progress; use crate::index::waiter::wait_for_index; #[cfg(feature = "remote")] pub(crate) use add_data::PreprocessingOutput; @@ -2276,13 +2277,21 @@ impl BaseTable for NativeTable { let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params)); + let tracker_for_tasks = output.tracker.clone(); + if let Some(ref t) = tracker_for_tasks { + t.set_total_tasks(num_partitions); + } + let _finish = write_progress::FinishOnDrop(output.tracker); + // Execute all partitions in parallel. 
+ pub(crate) progress_callback: Option<ProgressCallback>,
+ pub tracker: Option<Arc<WriteProgressTracker>>,
 partial_transactions: Arc<Mutex<Vec<Transaction>>>,
+ fn metrics(&self) -> Option<MetricsSet> {
source: Mutex<Box<dyn Scannable>>,
 num_rows: Option<usize>,
 properties: PlanProperties,
+ tracker: Option<Arc<WriteProgressTracker>>,
+pub type ProgressCallback = Arc<Mutex<dyn FnMut(&WriteProgress) + Send + 'static>>;
+ active_tasks: Arc<AtomicUsize>,
+ total_tasks: AtomicUsize,
+ start: Instant,
+ /// Known total rows from the input source, if available.
+ total_rows: Option<usize>,
+pub(crate) struct ActiveTaskGuard(Arc<AtomicUsize>);
+pub(crate) struct FinishOnDrop(pub Option<Arc<WriteProgressTracker>>);
+ let seen_done = Arc::new(std::sync::Mutex::new(Vec::<bool>::new()));
+ let new_data: Box<dyn arrow_array::RecordBatchReader + Send> =
+ Box::new(RecordBatchIterator::new(