feat: introduce ArrowNative wrapper struct for adding data that is already a RecordBatchReader (#1139)

In commit 2de226220b I added a new `IntoArrow` trait for adding data
into a table.
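Unfortunately, it seems my approach to implementing the trait for
"things that are already record batch readers" was flawed: the trait was
implemented only for `Box<T>`. This PR corrects that flaw by implementing
`IntoArrow` for any `T: RecordBatchReader + Send + 'static` instead and,
conveniently, removes the need to box readers at all (though it is still
OK to pass a reader that is already boxed).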
This commit is contained in:
Weston Pace
2024-03-20 13:28:17 -07:00
parent f6e9f8e3f4
commit 968c62cb8f
7 changed files with 26 additions and 49 deletions

View File

@@ -124,7 +124,7 @@ impl Connection {
let mode = Self::parse_create_mode_str(&mode)?;
let tbl = self
.get_inner()?
.create_table(&name, Box::new(batches))
.create_table(&name, batches)
.mode(mode)
.execute()
.await

View File

@@ -89,7 +89,7 @@ impl Table {
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let mut op = self.inner_ref()?.add(Box::new(batches));
let mut op = self.inner_ref()?.add(batches);
op = if mode == "append" {
op.mode(AddDataMode::Append)

View File

@@ -95,7 +95,7 @@ impl Connection {
let mode = Self::parse_create_mode_str(mode)?;
let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
future_into_py(self_.py(), async move {
let table = inner
.create_table(name, batches)

View File

@@ -64,7 +64,7 @@ impl Table {
}
pub fn add<'a>(self_: PyRef<'a, Self>, data: &PyAny, mode: String) -> PyResult<&'a PyAny> {
let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
let mut op = self_.inner_ref()?.add(batches);
if mode == "append" {
op = op.mode(AddDataMode::Append);

View File

@@ -80,7 +80,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, Box::new(batch_reader))
.create_table(&table_name, batch_reader)
.write_options(WriteOptions {
lance_write_params: Some(params),
})
@@ -126,7 +126,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let add_result = table
.add(Box::new(batch_reader))
.add(batch_reader)
.write_options(WriteOptions {
lance_write_params: Some(params),
})

View File

@@ -114,8 +114,8 @@ pub trait IntoArrow {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}
impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for Box<T> {
impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
Ok(self)
Ok(Box::new(self))
}
}

View File

@@ -1601,11 +1601,7 @@ mod tests {
let batches = make_test_batches();
let schema = batches.schema().clone();
let table = conn
.create_table("test", Box::new(batches))
.execute()
.await
.unwrap();
let table = conn.create_table("test", batches).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
@@ -1619,7 +1615,7 @@ mod tests {
schema.clone(),
);
table.add(Box::new(new_batches)).execute().await.unwrap();
table.add(new_batches).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 20);
assert_eq!(table.name(), "test");
}
@@ -1633,7 +1629,7 @@ mod tests {
// Create a dataset with i=0..10
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("my_table", Box::new(batches))
.create_table("my_table", batches)
.execute()
.await
.unwrap();
@@ -1681,11 +1677,7 @@ mod tests {
let batches = make_test_batches();
let schema = batches.schema().clone();
let table = conn
.create_table("test", Box::new(batches))
.execute()
.await
.unwrap();
let table = conn.create_table("test", batches).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let batches = vec![RecordBatch::try_new(
@@ -1700,7 +1692,7 @@ mod tests {
// Can overwrite using AddDataOptions::mode
table
.add(Box::new(new_batches))
.add(new_batches)
.mode(AddDataMode::Overwrite)
.execute()
.await
@@ -1718,7 +1710,7 @@ mod tests {
let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
table
.add(Box::new(new_batches))
.add(new_batches)
.write_options(WriteOptions {
lance_write_params: Some(param),
})
@@ -1763,7 +1755,7 @@ mod tests {
);
let table = conn
.create_table("my_table", Box::new(record_batch_iter))
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
@@ -1900,7 +1892,7 @@ mod tests {
);
let table = conn
.create_table("my_table", Box::new(record_batch_iter))
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
@@ -2021,7 +2013,7 @@ mod tests {
.await
.unwrap();
let tbl = conn
.create_table("my_table", Box::new(make_test_batches()))
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
@@ -2060,7 +2052,7 @@ mod tests {
let batches = make_test_batches();
conn.create_table("my_table", Box::new(batches))
conn.create_table("my_table", batches)
.execute()
.await
.unwrap();
@@ -2153,11 +2145,7 @@ mod tests {
schema,
);
let table = conn
.create_table("test", Box::new(batches))
.execute()
.await
.unwrap();
let table = conn.create_table("test", batches).execute().await.unwrap();
assert_eq!(
table
@@ -2228,7 +2216,7 @@ mod tests {
Ok(FixedSizeListArray::from(data))
}
fn some_sample_data() -> impl RecordBatchReader {
fn some_sample_data() -> Box<dyn RecordBatchReader + Send> {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1]))],
@@ -2237,7 +2225,7 @@ mod tests {
let schema = batch.schema().clone();
let batch = Ok(batch);
RecordBatchIterator::new(vec![batch], schema)
Box::new(RecordBatchIterator::new(vec![batch], schema))
}
#[tokio::test]
@@ -2254,10 +2242,7 @@ mod tests {
let table = conn
.create_table(
"my_table",
Box::new(RecordBatchIterator::new(
vec![Ok(batch.clone())],
batch.schema(),
)),
RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()),
)
.execute()
.await
@@ -2321,7 +2306,7 @@ mod tests {
assert_eq!(table1.count_rows(None).await.unwrap(), 0);
assert_eq!(table2.count_rows(None).await.unwrap(), 0);
table1.add(Box::new(data)).execute().await.unwrap();
table1.add(data).execute().await.unwrap();
assert_eq!(table1.count_rows(None).await.unwrap(), 1);
match interval {
@@ -2354,21 +2339,13 @@ mod tests {
.await
.unwrap();
let table = conn
.create_table("my_table", Box::new(some_sample_data()))
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
let version = table.version().await.unwrap();
table
.add(Box::new(some_sample_data()))
.execute()
.await
.unwrap();
table.add(some_sample_data()).execute().await.unwrap();
table.checkout(version).await.unwrap();
assert!(table
.add(Box::new(some_sample_data()))
.execute()
.await
.is_err())
assert!(table.add(some_sample_data()).execute().await.is_err())
}
}