From 7d790bd9e7496c51f12ed18860a66af8b714e823 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 20 Mar 2024 13:28:17 -0700 Subject: [PATCH] feat: introduce ArrowNative wrapper struct for adding data that is already a RecordBatchReader (#1139) In https://github.com/lancedb/lancedb/commit/2de226220b78c51b22319fb4fbc4678f4d269d65 I added a new `IntoArrow` trait for adding data into a table. Unfortunately, it seems my approach for implementing the trait for "things that are already record batch readers" was flawed. This PR corrects that flaw and, conveniently, removes the need to box readers at all (though it is ok if you do). --- nodejs/src/connection.rs | 2 +- nodejs/src/table.rs | 2 +- python/src/connection.rs | 2 +- python/src/table.rs | 2 +- rust/ffi/node/src/table.rs | 4 +-- rust/lancedb/src/arrow.rs | 4 +-- rust/lancedb/src/table.rs | 59 ++++++++++++-------------------------- 7 files changed, 26 insertions(+), 49 deletions(-) diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index 295fccf0..6f569473 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -124,7 +124,7 @@ impl Connection { let mode = Self::parse_create_mode_str(&mode)?; let tbl = self .get_inner()? - .create_table(&name, Box::new(batches)) + .create_table(&name, batches) .mode(mode) .execute() .await diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 4c3e4d33..c1cf7d46 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -89,7 +89,7 @@ impl Table { pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> { let batches = ipc_file_to_batches(buf.to_vec()) .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; - let mut op = self.inner_ref()?.add(Box::new(batches)); + let mut op = self.inner_ref()?.add(batches); op = if mode == "append" { op.mode(AddDataMode::Append) diff --git a/python/src/connection.rs b/python/src/connection.rs index 93f5332a..eb2bfae5 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -95,7 +95,7 @@ impl Connection { let mode = Self::parse_create_mode_str(mode)?; - let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?); + let batches = ArrowArrayStreamReader::from_pyarrow(data)?; future_into_py(self_.py(), async move { let table = inner .create_table(name, batches) diff --git a/python/src/table.rs b/python/src/table.rs index 0f4ed73d..05aa408e 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -64,7 +64,7 @@ impl Table { } pub fn add<'a>(self_: PyRef<'a, Self>, data: &PyAny, mode: String) -> PyResult<&'a PyAny> { - let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?); + let batches = ArrowArrayStreamReader::from_pyarrow(data)?; let mut op = self_.inner_ref()?.add(batches); if mode == "append" { op = op.mode(AddDataMode::Append); diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 4e35825b..13f1e895 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -80,7 +80,7 @@ impl JsTable { rt.spawn(async move { let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); let table_rst = database - .create_table(&table_name, Box::new(batch_reader)) + .create_table(&table_name, batch_reader) .write_options(WriteOptions { lance_write_params: Some(params), }) @@ -126,7 +126,7 @@ impl JsTable { rt.spawn(async move { let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); let add_result = table - .add(Box::new(batch_reader)) + .add(batch_reader) .write_options(WriteOptions { lance_write_params: Some(params), }) diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index 55e34a60..da990975 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -114,8 +114,8 @@ pub trait IntoArrow { fn into_arrow(self) -> Result>; } -impl IntoArrow for Box { +impl IntoArrow for T { fn into_arrow(self) -> Result> { - Ok(self) + Ok(Box::new(self)) } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 20208925..252267ef 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1601,11 +1601,7 @@ mod tests { let batches = make_test_batches(); let schema = batches.schema().clone(); - let table = conn - .create_table("test", Box::new(batches)) - .execute() - .await - .unwrap(); + let table = conn.create_table("test", batches).execute().await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); let new_batches = RecordBatchIterator::new( @@ -1619,7 +1615,7 @@ mod tests { schema.clone(), ); - table.add(Box::new(new_batches)).execute().await.unwrap(); + table.add(new_batches).execute().await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 20); assert_eq!(table.name(), "test"); } @@ -1633,7 +1629,7 @@ mod tests { // Create a dataset with i=0..10 let batches = merge_insert_test_batches(0, 0); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1681,11 +1677,7 @@ mod tests { let batches = make_test_batches(); let schema = batches.schema().clone(); - let table = conn - .create_table("test", Box::new(batches)) - .execute() - .await - .unwrap(); + let table = conn.create_table("test", batches).execute().await.unwrap(); assert_eq!(table.count_rows(None).await.unwrap(), 10); let batches = vec![RecordBatch::try_new( @@ -1700,7 +1692,7 @@ mod tests { // Can overwrite using AddDataOptions::mode table - .add(Box::new(new_batches)) + .add(new_batches) .mode(AddDataMode::Overwrite) .execute() .await @@ -1718,7 +1710,7 @@ mod tests { let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone()); table - .add(Box::new(new_batches)) + .add(new_batches) .write_options(WriteOptions { lance_write_params: Some(param), }) @@ -1763,7 +1755,7 @@ mod tests { ); let table = conn - .create_table("my_table", Box::new(record_batch_iter)) + .create_table("my_table", record_batch_iter) .execute() .await .unwrap(); @@ -1900,7 +1892,7 @@ mod tests { ); let table = conn - .create_table("my_table", Box::new(record_batch_iter)) + .create_table("my_table", record_batch_iter) .execute() .await .unwrap(); @@ -2021,7 +2013,7 @@ mod tests { .await .unwrap(); let tbl = conn - .create_table("my_table", Box::new(make_test_batches())) + .create_table("my_table", make_test_batches()) .execute() .await .unwrap(); @@ -2060,7 +2052,7 @@ mod tests { let batches = make_test_batches(); - conn.create_table("my_table", Box::new(batches)) + conn.create_table("my_table", batches) .execute() .await .unwrap(); @@ -2153,11 +2145,7 @@ mod tests { schema, ); - let table = conn - .create_table("test", Box::new(batches)) - .execute() - .await - .unwrap(); + let table = conn.create_table("test", batches).execute().await.unwrap(); assert_eq!( table @@ -2228,7 +2216,7 @@ mod tests { Ok(FixedSizeListArray::from(data)) } - fn some_sample_data() -> impl RecordBatchReader { + fn some_sample_data() -> Box { let batch = RecordBatch::try_new( Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])), vec![Arc::new(Int32Array::from(vec![1]))], @@ -2237,7 +2225,7 @@ mod tests { let schema = batch.schema().clone(); let batch = Ok(batch); - RecordBatchIterator::new(vec![batch], schema) + Box::new(RecordBatchIterator::new(vec![batch], schema)) } #[tokio::test] @@ -2254,10 +2242,7 @@ mod tests { let table = conn .create_table( "my_table", - Box::new(RecordBatchIterator::new( - vec![Ok(batch.clone())], - batch.schema(), - )), + RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), ) .execute() .await @@ -2321,7 +2306,7 @@ mod tests { assert_eq!(table1.count_rows(None).await.unwrap(), 0); assert_eq!(table2.count_rows(None).await.unwrap(), 0); - table1.add(Box::new(data)).execute().await.unwrap(); + table1.add(data).execute().await.unwrap(); assert_eq!(table1.count_rows(None).await.unwrap(), 1); match interval { @@ -2354,21 +2339,13 @@ mod tests { .await .unwrap(); let table = conn - .create_table("my_table", Box::new(some_sample_data())) + .create_table("my_table", some_sample_data()) .execute() .await .unwrap(); let version = table.version().await.unwrap(); - table - .add(Box::new(some_sample_data())) - .execute() - .await - .unwrap(); + table.add(some_sample_data()).execute().await.unwrap(); table.checkout(version).await.unwrap(); - assert!(table - .add(Box::new(some_sample_data())) - .execute() - .await - .is_err()) + assert!(table.add(some_sample_data()).execute().await.is_err()) } }