diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 06f4854a9..456c242ee 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -23,9 +23,7 @@ pub use lance::dataset::ColumnAlteration; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; pub use lance::dataset::Version; -use lance::dataset::{ - InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams, -}; +use lance::dataset::{InsertBuilder, WhenMatched, WriteMode, WriteParams}; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::index::vector::utils::infer_vector_dim; use lance::index::vector::VectorIndexParams; @@ -81,6 +79,8 @@ pub mod datafusion; pub(crate) mod dataset; pub mod delete; pub mod merge; +pub mod update; + use crate::index::waiter::wait_for_index; pub use chrono::Duration; pub use delete::DeleteResult; @@ -92,6 +92,7 @@ use lance::dataset::statistics::DatasetStatisticsExt; use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME; pub use lance_index::optimize::OptimizeOptions; use serde_with::skip_serializing_none; +pub use update::{UpdateBuilder, UpdateResult}; /// Defines the type of column #[derive(Debug, Clone, Serialize, Deserialize)] @@ -328,72 +329,6 @@ impl AddDataBuilder { } } -/// A builder for configuring an [`Table::update`] operation -#[derive(Debug, Clone)] -pub struct UpdateBuilder { - parent: Arc, - pub(crate) filter: Option, - pub(crate) columns: Vec<(String, String)>, -} - -impl UpdateBuilder { - fn new(parent: Arc) -> Self { - Self { - parent, - filter: None, - columns: Vec::new(), - } - } - - /// Limits the update operation to rows matching the given filter - /// - /// If a row does not match the filter then it will be left unchanged. - pub fn only_if(mut self, filter: impl Into) -> Self { - self.filter = Some(filter.into()); - self - } - - /// Specifies a column to update - /// - /// This method may be called multiple times to update multiple columns - /// - /// The `update_expr` should be an SQL expression explaining how to calculate - /// the new value for the column. The expression will be evaluated against the - /// previous row's value. - /// - /// # Examples - /// - /// ``` - /// # use lancedb::Table; - /// # async fn doctest_helper(tbl: Table) { - /// let mut operation = tbl.update(); - /// // Increments the `bird_count` value by 1 - /// operation = operation.column("bird_count", "bird_count + 1"); - /// operation.execute().await.unwrap(); - /// # } - /// ``` - pub fn column( - mut self, - column_name: impl Into, - update_expr: impl Into, - ) -> Self { - self.columns.push((column_name.into(), update_expr.into())); - self - } - - /// Executes the update operation. - /// Returns the update result - pub async fn execute(self) -> Result { - if self.columns.is_empty() { - Err(Error::InvalidInput { - message: "at least one column must be specified in an update operation".to_string(), - }) - } else { - self.parent.clone().update(self).await - } - } -} - /// Filters that can be used to limit the rows returned by a query pub enum Filter { /// A SQL filter string @@ -427,17 +362,6 @@ pub trait Tags: Send + Sync { async fn update(&mut self, tag: &str, version: u64) -> Result<()>; } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct UpdateResult { - #[serde(default)] - pub rows_updated: u64, - // The commit version associated with the operation. - // A version of `0` indicates compatibility with legacy servers that do not return - /// a commit version. - #[serde(default)] - pub version: u64, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct AddResult { // The commit version associated with the operation. @@ -2794,25 +2718,8 @@ impl BaseTable for NativeTable { } async fn update(&self, update: UpdateBuilder) -> Result { - let dataset = self.dataset.get().await?.clone(); - let mut builder = LanceUpdateBuilder::new(Arc::new(dataset)); - if let Some(predicate) = update.filter { - builder = builder.update_where(&predicate)?; - } - - for (column, value) in update.columns { - builder = builder.set(column, &value)?; - } - - let operation = builder.build()?; - let res = operation.execute().await?; - self.dataset - .set_latest(res.new_dataset.as_ref().clone()) - .await; - Ok(UpdateResult { - rows_updated: res.rows_updated, - version: res.new_dataset.version().version, - }) + // Delegate to the submodule implementation + update::execute_update(self, update).await } async fn create_plan( @@ -3395,15 +3302,12 @@ mod tests { use arrow_array::{ builder::{ListBuilder, StringBuilder}, - Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array, - Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator, - RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray, - UInt32Array, + Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, + RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, }; use arrow_array::{BinaryArray, LargeBinaryArray}; use arrow_data::ArrayDataBuilder; - use arrow_schema::{DataType, Field, Schema, TimeUnit}; - use futures::TryStreamExt; + use arrow_schema::{DataType, Field, Schema}; use lance::dataset::WriteMode; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lance::Dataset; @@ -3415,7 +3319,6 @@ mod tests { use crate::connection::ConnectBuilder; use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder}; use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder}; - use crate::query::{ExecutableQuery, QueryBase}; #[tokio::test] async fn test_open() { @@ -3637,306 +3540,6 @@ mod tests { assert_eq!(table.name(), "test"); } - #[tokio::test] - async fn test_update_with_predicate() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test.lance"); - let uri = dataset_path.to_str().unwrap(); - let conn = connect(uri) - .read_consistency_interval(Duration::from_secs(0)) - .execute() - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, false), - ])); - - let record_batch_iter = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(StringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), - schema.clone(), - ); - - let table = conn - .create_table("my_table", record_batch_iter) - .execute() - .await - .unwrap(); - - table - .update() - .only_if("id > 5") - .column("name", "'foo'") - .execute() - .await - .unwrap(); - - let mut batches = table - .query() - .select(Select::columns(&["id", "name"])) - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - while let Some(batch) = batches.pop() { - let ids = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect::>(); - let names = batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect::>(); - for (i, name) in names.iter().enumerate() { - let id = ids[i].unwrap(); - let name = name.unwrap(); - if id > 5 { - assert_eq!(name, "foo"); - } else { - assert_eq!(name, &format!("{}", (b'a' + id as u8) as char)); - } - } - } - } - - #[tokio::test] - async fn test_update_all_types() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test.lance"); - let uri = dataset_path.to_str().unwrap(); - let conn = connect(uri) - .read_consistency_interval(Duration::from_secs(0)) - .execute() - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("int32", DataType::Int32, false), - Field::new("int64", DataType::Int64, false), - Field::new("uint32", DataType::UInt32, false), - Field::new("string", DataType::Utf8, false), - Field::new("large_string", DataType::LargeUtf8, false), - Field::new("float32", DataType::Float32, false), - Field::new("float64", DataType::Float64, false), - Field::new("bool", DataType::Boolean, false), - Field::new("date32", DataType::Date32, false), - Field::new( - "timestamp_ns", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new( - "timestamp_ms", - DataType::Timestamp(TimeUnit::Millisecond, None), - false, - ), - Field::new( - "vec_f32", - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2), - false, - ), - Field::new( - "vec_f64", - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2), - false, - ), - ])); - - let record_batch_iter = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(Int64Array::from_iter_values(0..10)), - Arc::new(UInt32Array::from_iter_values(0..10)), - Arc::new(StringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - Arc::new(LargeStringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), - Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), - Arc::new(Into::::into(vec![ - true, false, true, false, true, false, true, false, true, false, - ])), - Arc::new(Date32Array::from_iter_values(0..10)), - Arc::new(TimestampNanosecondArray::from_iter_values(0..10)), - Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), - Arc::new( - create_fixed_size_list( - Float32Array::from_iter_values((0..20).map(|i| i as f32)), - 2, - ) - .unwrap(), - ), - Arc::new( - create_fixed_size_list( - Float64Array::from_iter_values((0..20).map(|i| i as f64)), - 2, - ) - .unwrap(), - ), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), - schema.clone(), - ); - - let table = conn - .create_table("my_table", record_batch_iter) - .execute() - .await - .unwrap(); - - // check it can do update for each type - let updates: Vec<(&str, &str)> = vec![ - ("string", "'foo'"), - ("large_string", "'large_foo'"), - ("int32", "1"), - ("int64", "1"), - ("uint32", "1"), - ("float32", "1.0"), - ("float64", "1.0"), - ("bool", "true"), - ("date32", "1"), - ("timestamp_ns", "1"), - ("timestamp_ms", "1"), - ("vec_f32", "[1.0, 1.0]"), - ("vec_f64", "[1.0, 1.0]"), - ]; - - let mut update_op = table.update(); - for (column, value) in updates { - update_op = update_op.column(column, value); - } - update_op.execute().await.unwrap(); - - let mut batches = table - .query() - .select(Select::columns(&[ - "string", - "large_string", - "int32", - "int64", - "uint32", - "float32", - "float64", - "bool", - "date32", - "timestamp_ns", - "timestamp_ms", - "vec_f32", - "vec_f64", - ])) - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - let batch = batches.pop().unwrap(); - - macro_rules! assert_column { - ($column:expr, $array_type:ty, $expected:expr) => { - let array = $column - .as_any() - .downcast_ref::<$array_type>() - .unwrap() - .iter() - .collect::>(); - for v in array { - assert_eq!(v, Some($expected)); - } - }; - } - - assert_column!(batch.column(0), StringArray, "foo"); - assert_column!(batch.column(1), LargeStringArray, "large_foo"); - assert_column!(batch.column(2), Int32Array, 1); - assert_column!(batch.column(3), Int64Array, 1); - assert_column!(batch.column(4), UInt32Array, 1); - assert_column!(batch.column(5), Float32Array, 1.0); - assert_column!(batch.column(6), Float64Array, 1.0); - assert_column!(batch.column(7), BooleanArray, true); - assert_column!(batch.column(8), Date32Array, 1); - assert_column!(batch.column(9), TimestampNanosecondArray, 1); - assert_column!(batch.column(10), TimestampMillisecondArray, 1); - - let array = batch - .column(11) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect::>(); - for v in array { - let v = v.unwrap(); - let f32array = v.as_any().downcast_ref::().unwrap(); - for v in f32array { - assert_eq!(v, Some(1.0)); - } - } - - let array = batch - .column(12) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect::>(); - for v in array { - let v = v.unwrap(); - let f64array = v.as_any().downcast_ref::().unwrap(); - for v in f64array { - assert_eq!(v, Some(1.0)); - } - } - } - - #[tokio::test] - async fn test_update_via_expr() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test.lance"); - let uri = dataset_path.to_str().unwrap(); - let conn = connect(uri) - .read_consistency_interval(Duration::from_secs(0)) - .execute() - .await - .unwrap(); - let tbl = conn - .create_table("my_table", make_test_batches()) - .execute() - .await - .unwrap(); - assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); - tbl.update().column("i", "i+1").execute().await.unwrap(); - assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); - } - #[derive(Default, Debug)] struct NoOpCacheWrapper { called: AtomicBool, diff --git a/rust/lancedb/src/table/update.rs b/rust/lancedb/src/table/update.rs new file mode 100644 index 000000000..6616dddc2 --- /dev/null +++ b/rust/lancedb/src/table/update.rs @@ -0,0 +1,441 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use lance::dataset::UpdateBuilder as LanceUpdateBuilder; +use serde::{Deserialize, Serialize}; + +use super::{BaseTable, NativeTable}; +use crate::Error; +use crate::Result; + +/// The result of an update operation +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct UpdateResult { + #[serde(default)] + pub rows_updated: u64, + /// The commit version associated with the operation. + #[serde(default)] + pub version: u64, +} + +/// A builder for configuring a [`crate::table::Table::update`] operation +#[derive(Debug, Clone)] +pub struct UpdateBuilder { + parent: Arc, + pub(crate) filter: Option, + pub(crate) columns: Vec<(String, String)>, +} + +impl UpdateBuilder { + pub(crate) fn new(parent: Arc) -> Self { + Self { + parent, + filter: None, + columns: Vec::new(), + } + } + + /// Limits the update operation to rows matching the given filter + /// + /// If a row does not match the filter then it will be left unchanged. + pub fn only_if(mut self, filter: impl Into) -> Self { + self.filter = Some(filter.into()); + self + } + + /// Specifies a column to update + /// + /// This method may be called multiple times to update multiple columns + /// + /// The `update_expr` should be an SQL expression explaining how to calculate + /// the new value for the column. The expression will be evaluated against the + /// previous row's value. + pub fn column( + mut self, + column_name: impl Into, + update_expr: impl Into, + ) -> Self { + self.columns.push((column_name.into(), update_expr.into())); + self + } + + /// Executes the update operation. + pub async fn execute(self) -> Result { + if self.columns.is_empty() { + Err(Error::InvalidInput { + message: "at least one column must be specified in an update operation".to_string(), + }) + } else { + self.parent.clone().update(self).await + } + } +} + +/// Internal implementation of the update logic +pub(crate) async fn execute_update( + table: &NativeTable, + update: UpdateBuilder, +) -> Result { + // 1. Snapshot the current dataset + let dataset = table.dataset.get().await?.clone(); + + // 2. Initialize the Lance Core builder + let mut builder = LanceUpdateBuilder::new(Arc::new(dataset)); + + // 3. Apply the filter (WHERE clause) + if let Some(predicate) = update.filter { + builder = builder.update_where(&predicate)?; + } + + // 4. Apply the columns (SET clause) + for (column, value) in update.columns { + builder = builder.set(column, &value)?; + } + + // 5. Execute the operation (Write new files) + let operation = builder.build()?; + let res = operation.execute().await?; + + // 6. Update the table's view of the latest version + table + .dataset + .set_latest(res.new_dataset.as_ref().clone()) + .await; + + Ok(UpdateResult { + rows_updated: res.rows_updated, + version: res.new_dataset.version().version, + }) +} + +#[cfg(test)] +mod tests { + use crate::connect; + use crate::query::QueryBase; + use crate::query::{ExecutableQuery, Select}; + use arrow_array::{ + record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, + Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator, + RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray, + UInt32Array, + }; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit}; + use futures::TryStreamExt; + use std::sync::Arc; + use std::time::Duration; + + #[tokio::test] + async fn test_update_all_types() { + let conn = connect("memory://") + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("uint32", DataType::UInt32, false), + Field::new("string", DataType::Utf8, false), + Field::new("large_string", DataType::LargeUtf8, false), + Field::new("float32", DataType::Float32, false), + Field::new("float64", DataType::Float64, false), + Field::new("bool", DataType::Boolean, false), + Field::new("date32", DataType::Date32, false), + Field::new( + "timestamp_ns", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new( + "timestamp_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new( + "vec_f32", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2), + false, + ), + Field::new( + "vec_f64", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2), + false, + ), + ])); + + let record_batch_iter = RecordBatchIterator::new( + vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(Int64Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + Arc::new(StringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(LargeStringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), + Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), + Arc::new(Into::::into(vec![ + true, false, true, false, true, false, true, false, true, false, + ])), + Arc::new(Date32Array::from_iter_values(0..10)), + Arc::new(TimestampNanosecondArray::from_iter_values(0..10)), + Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), + Arc::new( + create_fixed_size_list( + Float32Array::from_iter_values((0..20).map(|i| i as f32)), + 2, + ) + .unwrap(), + ), + Arc::new( + create_fixed_size_list( + Float64Array::from_iter_values((0..20).map(|i| i as f64)), + 2, + ) + .unwrap(), + ), + ], + ) + .unwrap()] + .into_iter() + .map(Ok), + schema.clone(), + ); + + let table = conn + .create_table("my_table", record_batch_iter) + .execute() + .await + .unwrap(); + + // check it can do update for each type + let updates: Vec<(&str, &str)> = vec![ + ("string", "'foo'"), + ("large_string", "'large_foo'"), + ("int32", "1"), + ("int64", "1"), + ("uint32", "1"), + ("float32", "1.0"), + ("float64", "1.0"), + ("bool", "true"), + ("date32", "1"), + ("timestamp_ns", "1"), + ("timestamp_ms", "1"), + ("vec_f32", "[1.0, 1.0]"), + ("vec_f64", "[1.0, 1.0]"), + ]; + + let mut update_op = table.update(); + for (column, value) in updates { + update_op = update_op.column(column, value); + } + update_op.execute().await.unwrap(); + + let mut batches = table + .query() + .select(Select::columns(&[ + "string", + "large_string", + "int32", + "int64", + "uint32", + "float32", + "float64", + "bool", + "date32", + "timestamp_ns", + "timestamp_ms", + "vec_f32", + "vec_f64", + ])) + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let batch = batches.pop().unwrap(); + + macro_rules! assert_column { + ($column:expr, $array_type:ty, $expected:expr) => { + let array = $column + .as_any() + .downcast_ref::<$array_type>() + .unwrap() + .iter() + .collect::>(); + for v in array { + assert_eq!(v, Some($expected)); + } + }; + } + + assert_column!(batch.column(0), StringArray, "foo"); + assert_column!(batch.column(1), LargeStringArray, "large_foo"); + assert_column!(batch.column(2), Int32Array, 1); + assert_column!(batch.column(3), Int64Array, 1); + assert_column!(batch.column(4), UInt32Array, 1); + assert_column!(batch.column(5), Float32Array, 1.0); + assert_column!(batch.column(6), Float64Array, 1.0); + assert_column!(batch.column(7), BooleanArray, true); + assert_column!(batch.column(8), Date32Array, 1); + assert_column!(batch.column(9), TimestampNanosecondArray, 1); + assert_column!(batch.column(10), TimestampMillisecondArray, 1); + + let array = batch + .column(11) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect::>(); + for v in array { + let v = v.unwrap(); + let f32array = v.as_any().downcast_ref::().unwrap(); + for v in f32array { + assert_eq!(v, Some(1.0)); + } + } + + let array = batch + .column(12) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect::>(); + for v in array { + let v = v.unwrap(); + let f64array = v.as_any().downcast_ref::().unwrap(); + for v in f64array { + assert_eq!(v, Some(1.0)); + } + } + } + ///Two helper functions + fn create_fixed_size_list( + values: T, + list_size: i32, + ) -> Result { + let list_type = DataType::FixedSizeList( + Arc::new(Field::new("item", values.data_type().clone(), true)), + list_size, + ); + let data = ArrayDataBuilder::new(list_type) + .len(values.len() / list_size as usize) + .add_child_data(values.into_data()) + .build() + .unwrap(); + + Ok(FixedSizeListArray::from(data)) + } + + fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + RecordBatchIterator::new( + vec![RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10))], + )], + schema, + ) + } + + #[tokio::test] + async fn test_update_with_predicate() { + let conn = connect("memory://") + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + + let batch = record_batch!( + ("id", Int32, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + ( + "name", + Utf8, + ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] + ) + ) + .unwrap(); + + let schema = batch.schema(); + // need the iterator for create table + let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema); + + let table = conn + .create_table("my_table", record_batch_iter) + .execute() + .await + .unwrap(); + + table + .update() + .only_if("id > 5") + .column("name", "'foo'") + .execute() + .await + .unwrap(); + + let mut batches = table + .query() + .select(Select::columns(&["id", "name"])) + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + while let Some(batch) = batches.pop() { + let ids = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect::>(); + let names = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect::>(); + for (i, name) in names.iter().enumerate() { + let id = ids[i].unwrap(); + let name = name.unwrap(); + if id > 5 { + assert_eq!(name, "foo"); + } else { + assert_eq!(name, &format!("{}", (b'a' + id as u8) as char)); + } + } + } + } + + #[tokio::test] + async fn test_update_via_expr() { + let conn = connect("memory://") + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + let tbl = conn + .create_table("my_table", make_test_batches()) + .execute() + .await + .unwrap(); + assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); + tbl.update().column("i", "i+1").execute().await.unwrap(); + assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); + } +}