From 394bb34fa28a3d0fe4477c4f3398d46d6d643be5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ghxst=20=E2=98=A0=EF=B8=8F?= Date: Wed, 17 Jun 2026 21:05:59 +0200 Subject: [PATCH] fix(rust): report local write progress bytes from Lance (#3422) Fixes #3360. This updates native table writes so local write progress uses Lance writer byte stats instead of Arrow in-memory batch size once write bytes are available. The change wires the existing `WriteProgressTracker` into `InsertExec` for native `add` writes, installs a Lance `WriteProgressFn` only when no lower-level callback is already configured, and keeps the existing public `InsertExec::new` signature unchanged. Validation: - `cargo test -p lancedb --features remote table::write_progress::tests::test_progress_uses_lance_write_bytes_for_local_tables -- --nocapture` passed: 1 passed, 0 failed. - `cargo test -p lancedb --features remote table::write_progress::tests -- --nocapture` passed: 7 passed, 0 failed. - `cargo check --quiet --features remote --tests --examples` passed. - `cargo fmt --all --check` passed. - `git diff --check` passed. - `git diff | gitleaks stdin --no-banner --redact --timeout 30` passed: no leaks found. I did not run the full `cargo test --quiet --features remote --tests` suite. Co-authored-by: Ghxst <200635707+GHX5T-SOL@users.noreply.github.com> --- rust/lancedb/src/table.rs | 8 ++- rust/lancedb/src/table/datafusion/insert.rs | 36 ++++++++++- rust/lancedb/src/table/write_progress.rs | 66 +++++++++++++++++++-- 3 files changed, 101 insertions(+), 9 deletions(-) diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 91f7080bf..ccc41b81e 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -2616,7 +2616,13 @@ impl BaseTable for NativeTable { output.plan }; - let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params)); + let insert_exec = Arc::new(InsertExec::new_with_tracker( + ds_wrapper.clone(), + ds, + plan, + lance_params, + output.tracker.clone(), + )); let tracker_for_tasks = output.tracker.clone(); if let Some(ref t) = tracker_for_tasks { diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs index f2cf21f13..0428e03b6 100644 --- a/rust/lancedb/src/table/datafusion/insert.rs +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -4,6 +4,7 @@ //! DataFusion ExecutionPlan for inserting data into LanceDB tables. use std::any::Any; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, LazyLock, Mutex}; use arrow_array::{RecordBatch, UInt64Array}; @@ -20,11 +21,12 @@ use datafusion_physical_plan::{ use futures::TryStreamExt; use lance::Dataset; use lance::dataset::transaction::{Operation, Transaction}; -use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams}; +use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams, WriteProgressFn}; use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter; use lance_table::format::Fragment; use crate::table::dataset::DatasetConsistencyWrapper; +use crate::table::write_progress::WriteProgressTracker; pub(crate) static COUNT_SCHEMA: LazyLock = LazyLock::new(|| { Arc::new(ArrowSchema::new(vec![Field::new( @@ -81,6 +83,7 @@ pub struct InsertExec { dataset: Arc, input: Arc, write_params: WriteParams, + tracker: Option>, properties: Arc, partial_transactions: Arc>>, metrics: ExecutionPlanMetricsSet, @@ -92,6 +95,16 @@ impl InsertExec { dataset: Arc, input: Arc, write_params: WriteParams, + ) -> Self { + Self::new_with_tracker(ds_wrapper, dataset, input, write_params, None) + } + + pub(crate) fn new_with_tracker( + ds_wrapper: DatasetConsistencyWrapper, + dataset: Arc, + input: Arc, + write_params: WriteParams, + tracker: Option>, ) -> Self { let schema = COUNT_SCHEMA.clone(); let num_partitions = input.output_partitioning().partition_count(); @@ -107,6 +120,7 @@ impl InsertExec { dataset, input, write_params, + tracker, properties: Arc::new(properties), partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))), metrics: ExecutionPlanMetricsSet::new(), @@ -161,11 +175,12 @@ impl ExecutionPlan for InsertExec { "InsertExec requires exactly one child".to_string(), )); } - Ok(Arc::new(Self::new( + Ok(Arc::new(Self::new_with_tracker( self.ds_wrapper.clone(), self.dataset.clone(), children[0].clone(), self.write_params.clone(), + self.tracker.clone(), ))) } @@ -176,10 +191,11 @@ impl ExecutionPlan for InsertExec { ) -> DataFusionResult { let input_stream = self.input.execute(partition, context)?; let dataset = self.dataset.clone(); - let write_params = self.write_params.clone(); + let mut write_params = self.write_params.clone(); let partial_transactions = self.partial_transactions.clone(); let total_partitions = self.input.output_partitioning().partition_count(); let ds_wrapper = self.ds_wrapper.clone(); + let tracker = self.tracker.clone(); let output_bytes = MetricBuilder::new(&self.metrics).output_bytes(partition); let input_schema = input_stream.schema(); @@ -195,6 +211,20 @@ impl ExecutionPlan for InsertExec { )); let stream = futures::stream::once(async move { + if let Some(tracker) = tracker + && write_params.write_progress.is_none() + { + let last_bytes = Arc::new(AtomicU64::new(0)); + write_params.write_progress = Some(WriteProgressFn::new(move |stats| { + let previous = last_bytes.swap(stats.bytes_written, Ordering::Relaxed); + if stats.bytes_written > previous { + let delta = + usize::try_from(stats.bytes_written - previous).unwrap_or(usize::MAX); + tracker.record_bytes(delta); + } + })); + } + let transaction = InsertBuilder::new(dataset.clone()) .with_params(&write_params) .execute_uncommitted_stream(input_stream) diff --git a/rust/lancedb/src/table/write_progress.rs b/rust/lancedb/src/table/write_progress.rs index 7a5c30008..191c3284d 100644 --- a/rust/lancedb/src/table/write_progress.rs +++ b/rust/lancedb/src/table/write_progress.rs @@ -142,11 +142,21 @@ impl WriteProgressTracker { cb(&progress); } - /// Record wire bytes from the insert layer (e.g. IPC-encoded bytes for - /// remote writes). When wire bytes are recorded, they take precedence over - /// the in-memory Arrow bytes tracked by [`record_batch`]. + /// Record wire bytes from the insert layer. + /// + /// These bytes may be IPC-encoded bytes for remote writes or bytes handed + /// to Lance's local writer. When wire bytes are recorded, they take + /// precedence over the in-memory Arrow bytes tracked by [`record_batch`]. pub fn record_bytes(&self, bytes: usize) { self.wire_bytes.fetch_add(bytes, Ordering::Relaxed); + let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner()); + let guard = self + .rows_and_bytes + .lock() + .unwrap_or_else(|e| e.into_inner()); + let progress = self.snapshot(guard.0, guard.1, false); + drop(guard); + cb(&progress); } /// Emit the final progress callback indicating the write is complete. @@ -169,8 +179,6 @@ impl WriteProgressTracker { let wire = self.wire_bytes.load(Ordering::Relaxed); // Prefer wire bytes (actual I/O size) when the insert layer is // tracking them; fall back to in-memory Arrow size otherwise. - // TODO: for local writes, track actual bytes written by Lance - // instead of using in-memory Arrow size as a proxy. let output_bytes = if wire > 0 { wire } else { in_memory_bytes }; WriteProgress { elapsed: self.start.elapsed(), @@ -383,6 +391,54 @@ mod tests { } } + #[tokio::test] + async fn test_progress_uses_lance_write_bytes_for_local_tables() { + let dir = tempfile::tempdir().unwrap(); + let db = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let table = db + .create_table("local_write_bytes", batch) + .execute() + .await + .unwrap(); + + let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap(); + let in_memory_bytes = new_data.get_array_memory_size(); + let final_bytes = Arc::new(AtomicUsize::new(0)); + let seen_non_memory_bytes = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let final_bytes_cb = final_bytes.clone(); + let seen_non_memory_bytes_cb = seen_non_memory_bytes.clone(); + + table + .add(new_data) + .write_parallelism(1) + .progress(move |p| { + if p.output_bytes() > 0 && p.output_bytes() != in_memory_bytes { + seen_non_memory_bytes_cb.store(true, Ordering::SeqCst); + } + if p.done() { + final_bytes_cb.store(p.output_bytes(), Ordering::SeqCst); + } + }) + .execute() + .await + .unwrap(); + + assert!( + seen_non_memory_bytes.load(Ordering::SeqCst), + "progress should report Lance writer bytes, not only Arrow memory bytes" + ); + assert_ne!( + final_bytes.load(Ordering::SeqCst), + in_memory_bytes, + "final progress bytes should come from Lance write stats" + ); + } + #[test] fn test_record_batch_recovers_from_poisoned_callback_lock() { use super::{ProgressCallback, WriteProgressTracker};