mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 19:10:39 +00:00
refactor: extract update logic to src/table/update.rs (#2964)
References #2949 Part 2 of table.rs refactor. Moved UpdateResult, UpdateBuilder, and execution logic to src/table/update.rs. No functional changes API remains identical. --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -23,9 +23,7 @@ pub use lance::dataset::ColumnAlteration;
|
||||
pub use lance::dataset::NewColumnTransform;
|
||||
pub use lance::dataset::ReadParams;
|
||||
pub use lance::dataset::Version;
|
||||
use lance::dataset::{
|
||||
InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
|
||||
};
|
||||
use lance::dataset::{InsertBuilder, WhenMatched, WriteMode, WriteParams};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::index::vector::utils::infer_vector_dim;
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
@@ -81,6 +79,8 @@ pub mod datafusion;
|
||||
pub(crate) mod dataset;
|
||||
pub mod delete;
|
||||
pub mod merge;
|
||||
pub mod update;
|
||||
|
||||
use crate::index::waiter::wait_for_index;
|
||||
pub use chrono::Duration;
|
||||
pub use delete::DeleteResult;
|
||||
@@ -92,6 +92,7 @@ use lance::dataset::statistics::DatasetStatisticsExt;
|
||||
use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME;
|
||||
pub use lance_index::optimize::OptimizeOptions;
|
||||
use serde_with::skip_serializing_none;
|
||||
pub use update::{UpdateBuilder, UpdateResult};
|
||||
|
||||
/// Defines the type of column
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -328,72 +329,6 @@ impl<T: IntoArrow> AddDataBuilder<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for configuring an [`Table::update`] operation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateBuilder {
|
||||
parent: Arc<dyn BaseTable>,
|
||||
pub(crate) filter: Option<String>,
|
||||
pub(crate) columns: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl UpdateBuilder {
|
||||
fn new(parent: Arc<dyn BaseTable>) -> Self {
|
||||
Self {
|
||||
parent,
|
||||
filter: None,
|
||||
columns: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Limits the update operation to rows matching the given filter
|
||||
///
|
||||
/// If a row does not match the filter then it will be left unchanged.
|
||||
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
|
||||
self.filter = Some(filter.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specifies a column to update
|
||||
///
|
||||
/// This method may be called multiple times to update multiple columns
|
||||
///
|
||||
/// The `update_expr` should be an SQL expression explaining how to calculate
|
||||
/// the new value for the column. The expression will be evaluated against the
|
||||
/// previous row's value.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # use lancedb::Table;
|
||||
/// # async fn doctest_helper(tbl: Table) {
|
||||
/// let mut operation = tbl.update();
|
||||
/// // Increments the `bird_count` value by 1
|
||||
/// operation = operation.column("bird_count", "bird_count + 1");
|
||||
/// operation.execute().await.unwrap();
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn column(
|
||||
mut self,
|
||||
column_name: impl Into<String>,
|
||||
update_expr: impl Into<String>,
|
||||
) -> Self {
|
||||
self.columns.push((column_name.into(), update_expr.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Executes the update operation.
|
||||
/// Returns the update result
|
||||
pub async fn execute(self) -> Result<UpdateResult> {
|
||||
if self.columns.is_empty() {
|
||||
Err(Error::InvalidInput {
|
||||
message: "at least one column must be specified in an update operation".to_string(),
|
||||
})
|
||||
} else {
|
||||
self.parent.clone().update(self).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters that can be used to limit the rows returned by a query
|
||||
pub enum Filter {
|
||||
/// A SQL filter string
|
||||
@@ -427,17 +362,6 @@ pub trait Tags: Send + Sync {
|
||||
async fn update(&mut self, tag: &str, version: u64) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub struct UpdateResult {
|
||||
#[serde(default)]
|
||||
pub rows_updated: u64,
|
||||
// The commit version associated with the operation.
|
||||
// A version of `0` indicates compatibility with legacy servers that do not return
|
||||
/// a commit version.
|
||||
#[serde(default)]
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub struct AddResult {
|
||||
// The commit version associated with the operation.
|
||||
@@ -2794,25 +2718,8 @@ impl BaseTable for NativeTable {
|
||||
}
|
||||
|
||||
async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> {
|
||||
let dataset = self.dataset.get().await?.clone();
|
||||
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
|
||||
if let Some(predicate) = update.filter {
|
||||
builder = builder.update_where(&predicate)?;
|
||||
}
|
||||
|
||||
for (column, value) in update.columns {
|
||||
builder = builder.set(column, &value)?;
|
||||
}
|
||||
|
||||
let operation = builder.build()?;
|
||||
let res = operation.execute().await?;
|
||||
self.dataset
|
||||
.set_latest(res.new_dataset.as_ref().clone())
|
||||
.await;
|
||||
Ok(UpdateResult {
|
||||
rows_updated: res.rows_updated,
|
||||
version: res.new_dataset.version().version,
|
||||
})
|
||||
// Delegate to the submodule implementation
|
||||
update::execute_update(self, update).await
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
@@ -3395,15 +3302,12 @@ mod tests {
|
||||
|
||||
use arrow_array::{
|
||||
builder::{ListBuilder, StringBuilder},
|
||||
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
|
||||
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
|
||||
UInt32Array,
|
||||
Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray,
|
||||
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
|
||||
};
|
||||
use arrow_array::{BinaryArray, LargeBinaryArray};
|
||||
use arrow_data::ArrayDataBuilder;
|
||||
use arrow_schema::{DataType, Field, Schema, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance::dataset::WriteMode;
|
||||
use lance::io::{ObjectStoreParams, WrappingObjectStore};
|
||||
use lance::Dataset;
|
||||
@@ -3415,7 +3319,6 @@ mod tests {
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder};
|
||||
use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder};
|
||||
use crate::query::{ExecutableQuery, QueryBase};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_open() {
|
||||
@@ -3637,306 +3540,6 @@ mod tests {
|
||||
assert_eq!(table.name(), "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_with_predicate() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
let conn = connect(uri)
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("name", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.update()
|
||||
.only_if("id > 5")
|
||||
.column("name", "'foo'")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut batches = table
|
||||
.query()
|
||||
.select(Select::columns(&["id", "name"]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
while let Some(batch) = batches.pop() {
|
||||
let ids = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
let names = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for (i, name) in names.iter().enumerate() {
|
||||
let id = ids[i].unwrap();
|
||||
let name = name.unwrap();
|
||||
if id > 5 {
|
||||
assert_eq!(name, "foo");
|
||||
} else {
|
||||
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_all_types() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
let conn = connect(uri)
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("int32", DataType::Int32, false),
|
||||
Field::new("int64", DataType::Int64, false),
|
||||
Field::new("uint32", DataType::UInt32, false),
|
||||
Field::new("string", DataType::Utf8, false),
|
||||
Field::new("large_string", DataType::LargeUtf8, false),
|
||||
Field::new("float32", DataType::Float32, false),
|
||||
Field::new("float64", DataType::Float64, false),
|
||||
Field::new("bool", DataType::Boolean, false),
|
||||
Field::new("date32", DataType::Date32, false),
|
||||
Field::new(
|
||||
"timestamp_ns",
|
||||
DataType::Timestamp(TimeUnit::Nanosecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"timestamp_ms",
|
||||
DataType::Timestamp(TimeUnit::Millisecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f32",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f64",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(Int64Array::from_iter_values(0..10)),
|
||||
Arc::new(UInt32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(LargeStringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
|
||||
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
|
||||
Arc::new(Into::<BooleanArray>::into(vec![
|
||||
true, false, true, false, true, false, true, false, true, false,
|
||||
])),
|
||||
Arc::new(Date32Array::from_iter_values(0..10)),
|
||||
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
|
||||
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// check it can do update for each type
|
||||
let updates: Vec<(&str, &str)> = vec![
|
||||
("string", "'foo'"),
|
||||
("large_string", "'large_foo'"),
|
||||
("int32", "1"),
|
||||
("int64", "1"),
|
||||
("uint32", "1"),
|
||||
("float32", "1.0"),
|
||||
("float64", "1.0"),
|
||||
("bool", "true"),
|
||||
("date32", "1"),
|
||||
("timestamp_ns", "1"),
|
||||
("timestamp_ms", "1"),
|
||||
("vec_f32", "[1.0, 1.0]"),
|
||||
("vec_f64", "[1.0, 1.0]"),
|
||||
];
|
||||
|
||||
let mut update_op = table.update();
|
||||
for (column, value) in updates {
|
||||
update_op = update_op.column(column, value);
|
||||
}
|
||||
update_op.execute().await.unwrap();
|
||||
|
||||
let mut batches = table
|
||||
.query()
|
||||
.select(Select::columns(&[
|
||||
"string",
|
||||
"large_string",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint32",
|
||||
"float32",
|
||||
"float64",
|
||||
"bool",
|
||||
"date32",
|
||||
"timestamp_ns",
|
||||
"timestamp_ms",
|
||||
"vec_f32",
|
||||
"vec_f64",
|
||||
]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = batches.pop().unwrap();
|
||||
|
||||
macro_rules! assert_column {
|
||||
($column:expr, $array_type:ty, $expected:expr) => {
|
||||
let array = $column
|
||||
.as_any()
|
||||
.downcast_ref::<$array_type>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
assert_eq!(v, Some($expected));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
assert_column!(batch.column(0), StringArray, "foo");
|
||||
assert_column!(batch.column(1), LargeStringArray, "large_foo");
|
||||
assert_column!(batch.column(2), Int32Array, 1);
|
||||
assert_column!(batch.column(3), Int64Array, 1);
|
||||
assert_column!(batch.column(4), UInt32Array, 1);
|
||||
assert_column!(batch.column(5), Float32Array, 1.0);
|
||||
assert_column!(batch.column(6), Float64Array, 1.0);
|
||||
assert_column!(batch.column(7), BooleanArray, true);
|
||||
assert_column!(batch.column(8), Date32Array, 1);
|
||||
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
|
||||
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
|
||||
|
||||
let array = batch
|
||||
.column(11)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
|
||||
for v in f32array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
|
||||
let array = batch
|
||||
.column(12)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
for v in f64array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_via_expr() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
let conn = connect(uri)
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let tbl = conn
|
||||
.create_table("my_table", make_test_batches())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
|
||||
tbl.update().column("i", "i+1").execute().await.unwrap();
|
||||
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct NoOpCacheWrapper {
|
||||
called: AtomicBool,
|
||||
|
||||
441
rust/lancedb/src/table/update.rs
Normal file
441
rust/lancedb/src/table/update.rs
Normal file
@@ -0,0 +1,441 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use lance::dataset::UpdateBuilder as LanceUpdateBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{BaseTable, NativeTable};
|
||||
use crate::Error;
|
||||
use crate::Result;
|
||||
|
||||
/// The result of an update operation
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub struct UpdateResult {
|
||||
#[serde(default)]
|
||||
pub rows_updated: u64,
|
||||
/// The commit version associated with the operation.
|
||||
#[serde(default)]
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
/// A builder for configuring a [`crate::table::Table::update`] operation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateBuilder {
|
||||
parent: Arc<dyn BaseTable>,
|
||||
pub(crate) filter: Option<String>,
|
||||
pub(crate) columns: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl UpdateBuilder {
|
||||
pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
|
||||
Self {
|
||||
parent,
|
||||
filter: None,
|
||||
columns: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Limits the update operation to rows matching the given filter
|
||||
///
|
||||
/// If a row does not match the filter then it will be left unchanged.
|
||||
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
|
||||
self.filter = Some(filter.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Specifies a column to update
|
||||
///
|
||||
/// This method may be called multiple times to update multiple columns
|
||||
///
|
||||
/// The `update_expr` should be an SQL expression explaining how to calculate
|
||||
/// the new value for the column. The expression will be evaluated against the
|
||||
/// previous row's value.
|
||||
pub fn column(
|
||||
mut self,
|
||||
column_name: impl Into<String>,
|
||||
update_expr: impl Into<String>,
|
||||
) -> Self {
|
||||
self.columns.push((column_name.into(), update_expr.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Executes the update operation.
|
||||
pub async fn execute(self) -> Result<UpdateResult> {
|
||||
if self.columns.is_empty() {
|
||||
Err(Error::InvalidInput {
|
||||
message: "at least one column must be specified in an update operation".to_string(),
|
||||
})
|
||||
} else {
|
||||
self.parent.clone().update(self).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal implementation of the update logic
|
||||
pub(crate) async fn execute_update(
|
||||
table: &NativeTable,
|
||||
update: UpdateBuilder,
|
||||
) -> Result<UpdateResult> {
|
||||
// 1. Snapshot the current dataset
|
||||
let dataset = table.dataset.get().await?.clone();
|
||||
|
||||
// 2. Initialize the Lance Core builder
|
||||
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
|
||||
|
||||
// 3. Apply the filter (WHERE clause)
|
||||
if let Some(predicate) = update.filter {
|
||||
builder = builder.update_where(&predicate)?;
|
||||
}
|
||||
|
||||
// 4. Apply the columns (SET clause)
|
||||
for (column, value) in update.columns {
|
||||
builder = builder.set(column, &value)?;
|
||||
}
|
||||
|
||||
// 5. Execute the operation (Write new files)
|
||||
let operation = builder.build()?;
|
||||
let res = operation.execute().await?;
|
||||
|
||||
// 6. Update the table's view of the latest version
|
||||
table
|
||||
.dataset
|
||||
.set_latest(res.new_dataset.as_ref().clone())
|
||||
.await;
|
||||
|
||||
Ok(UpdateResult {
|
||||
rows_updated: res.rows_updated,
|
||||
version: res.new_dataset.version().version,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::connect;
|
||||
use crate::query::QueryBase;
|
||||
use crate::query::{ExecutableQuery, Select};
|
||||
use arrow_array::{
|
||||
record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array,
|
||||
Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
|
||||
UInt32Array,
|
||||
};
|
||||
use arrow_data::ArrayDataBuilder;
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_all_types() {
|
||||
let conn = connect("memory://")
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("int32", DataType::Int32, false),
|
||||
Field::new("int64", DataType::Int64, false),
|
||||
Field::new("uint32", DataType::UInt32, false),
|
||||
Field::new("string", DataType::Utf8, false),
|
||||
Field::new("large_string", DataType::LargeUtf8, false),
|
||||
Field::new("float32", DataType::Float32, false),
|
||||
Field::new("float64", DataType::Float64, false),
|
||||
Field::new("bool", DataType::Boolean, false),
|
||||
Field::new("date32", DataType::Date32, false),
|
||||
Field::new(
|
||||
"timestamp_ns",
|
||||
DataType::Timestamp(TimeUnit::Nanosecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"timestamp_ms",
|
||||
DataType::Timestamp(TimeUnit::Millisecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f32",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f64",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(Int64Array::from_iter_values(0..10)),
|
||||
Arc::new(UInt32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(LargeStringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
|
||||
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
|
||||
Arc::new(Into::<BooleanArray>::into(vec![
|
||||
true, false, true, false, true, false, true, false, true, false,
|
||||
])),
|
||||
Arc::new(Date32Array::from_iter_values(0..10)),
|
||||
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
|
||||
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// check it can do update for each type
|
||||
let updates: Vec<(&str, &str)> = vec![
|
||||
("string", "'foo'"),
|
||||
("large_string", "'large_foo'"),
|
||||
("int32", "1"),
|
||||
("int64", "1"),
|
||||
("uint32", "1"),
|
||||
("float32", "1.0"),
|
||||
("float64", "1.0"),
|
||||
("bool", "true"),
|
||||
("date32", "1"),
|
||||
("timestamp_ns", "1"),
|
||||
("timestamp_ms", "1"),
|
||||
("vec_f32", "[1.0, 1.0]"),
|
||||
("vec_f64", "[1.0, 1.0]"),
|
||||
];
|
||||
|
||||
let mut update_op = table.update();
|
||||
for (column, value) in updates {
|
||||
update_op = update_op.column(column, value);
|
||||
}
|
||||
update_op.execute().await.unwrap();
|
||||
|
||||
let mut batches = table
|
||||
.query()
|
||||
.select(Select::columns(&[
|
||||
"string",
|
||||
"large_string",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint32",
|
||||
"float32",
|
||||
"float64",
|
||||
"bool",
|
||||
"date32",
|
||||
"timestamp_ns",
|
||||
"timestamp_ms",
|
||||
"vec_f32",
|
||||
"vec_f64",
|
||||
]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = batches.pop().unwrap();
|
||||
|
||||
macro_rules! assert_column {
|
||||
($column:expr, $array_type:ty, $expected:expr) => {
|
||||
let array = $column
|
||||
.as_any()
|
||||
.downcast_ref::<$array_type>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
assert_eq!(v, Some($expected));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
assert_column!(batch.column(0), StringArray, "foo");
|
||||
assert_column!(batch.column(1), LargeStringArray, "large_foo");
|
||||
assert_column!(batch.column(2), Int32Array, 1);
|
||||
assert_column!(batch.column(3), Int64Array, 1);
|
||||
assert_column!(batch.column(4), UInt32Array, 1);
|
||||
assert_column!(batch.column(5), Float32Array, 1.0);
|
||||
assert_column!(batch.column(6), Float64Array, 1.0);
|
||||
assert_column!(batch.column(7), BooleanArray, true);
|
||||
assert_column!(batch.column(8), Date32Array, 1);
|
||||
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
|
||||
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
|
||||
|
||||
let array = batch
|
||||
.column(11)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
|
||||
for v in f32array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
|
||||
let array = batch
|
||||
.column(12)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
for v in f64array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
///Two helper functions
|
||||
fn create_fixed_size_list<T: Array>(
|
||||
values: T,
|
||||
list_size: i32,
|
||||
) -> Result<FixedSizeListArray, ArrowError> {
|
||||
let list_type = DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", values.data_type().clone(), true)),
|
||||
list_size,
|
||||
);
|
||||
let data = ArrayDataBuilder::new(list_type)
|
||||
.len(values.len() / list_size as usize)
|
||||
.add_child_data(values.into_data())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Ok(FixedSizeListArray::from(data))
|
||||
}
|
||||
|
||||
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
|
||||
RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from_iter_values(0..10))],
|
||||
)],
|
||||
schema,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_with_predicate() {
|
||||
let conn = connect("memory://")
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = record_batch!(
|
||||
("id", Int32, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
(
|
||||
"name",
|
||||
Utf8,
|
||||
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
|
||||
)
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let schema = batch.schema();
|
||||
// need the iterator for create table
|
||||
let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.update()
|
||||
.only_if("id > 5")
|
||||
.column("name", "'foo'")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut batches = table
|
||||
.query()
|
||||
.select(Select::columns(&["id", "name"]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
while let Some(batch) = batches.pop() {
|
||||
let ids = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
let names = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for (i, name) in names.iter().enumerate() {
|
||||
let id = ids[i].unwrap();
|
||||
let name = name.unwrap();
|
||||
if id > 5 {
|
||||
assert_eq!(name, "foo");
|
||||
} else {
|
||||
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_via_expr() {
|
||||
let conn = connect("memory://")
|
||||
.read_consistency_interval(Duration::from_secs(0))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let tbl = conn
|
||||
.create_table("my_table", make_test_batches())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
|
||||
tbl.update().column("i", "i+1").execute().await.unwrap();
|
||||
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user