diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index b50ae4e7a..47831fee2 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -25,7 +25,7 @@ use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; use crate::table::query::create_multi_vector_plan; -use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics}; +use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics}; use crate::utils::background_cache::BackgroundCache; use crate::utils::{ resolve_arrow_field_path, supported_btree_data_type, supported_vector_data_type, @@ -1483,9 +1483,13 @@ impl BaseTable for RemoteTable { Ok(update_response) } - async fn delete(&self, predicate: &str) -> Result { + async fn delete(&self, predicate: Predicate<'_>) -> Result { self.check_mutable().await?; - let body = serde_json::json!({ "predicate": predicate }); + let predicate_sql = match predicate { + Predicate::String(s) => s.to_string(), + Predicate::Expr(expr) => expr_to_sql_string(expr)?, + }; + let body = serde_json::json!({ "predicate": predicate_sql }); let request = self .client .post(&format!("/v1/table/{}/delete/", self.identifier)) @@ -2851,6 +2855,33 @@ mod tests { assert_eq!(result.version, if old_server { 0 } else { 43 }); } + #[tokio::test] + async fn test_delete_expr() { + use datafusion_expr::{col, lit}; + + let table = Table::new_with_handler("my_table", move |request| { + if request.url().path() == "/v1/table/my_table/delete/" { + assert_eq!(request.method(), "POST"); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + assert!(body.get("predicate").unwrap().is_string()); + + http::Response::builder() + .status(200) + .body(r#"{"num_deleted_rows": 4, "version": 2}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let expr = col("id").gt(lit(5)); + let result = table.delete(&expr).await.unwrap(); + assert_eq!(result.num_deleted_rows, 4); + assert_eq!(result.version, 2); + } + #[rstest] #[case(true)] #[case(false)] diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 03f967e6e..6d639a95b 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -253,6 +253,36 @@ pub enum Filter { Datafusion(Expr), } +/// A predicate for filtering rows in delete operations. +/// +/// Accepts either a SQL string or a DataFusion [`Expr`]. Use the [`From`] +/// implementations to convert from `&str` or `&Expr` automatically. +/// See [`Table::delete`] for usage examples. +pub enum Predicate<'a> { + /// A SQL predicate string + String(&'a str), + /// A DataFusion logical expression + Expr(&'a Expr), +} + +impl<'a> From<&'a str> for Predicate<'a> { + fn from(s: &'a str) -> Self { + Predicate::String(s) + } +} + +impl<'a> From<&'a String> for Predicate<'a> { + fn from(s: &'a String) -> Self { + Predicate::String(s.as_str()) + } +} + +impl<'a> From<&'a Expr> for Predicate<'a> { + fn from(e: &'a Expr) -> Self { + Predicate::Expr(e) + } +} + #[async_trait] pub trait Tags: Send + Sync { /// List the tags of the table. @@ -491,8 +521,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { /// Add new records to the table. async fn add(&self, add: AddDataBuilder) -> Result; - /// Delete rows from the table. - async fn delete(&self, predicate: &str) -> Result; + /// Delete rows from the table matching the given [`Predicate`]. + async fn delete(&self, predicate: Predicate<'_>) -> Result; /// Update rows in the table. async fn update(&self, update: UpdateBuilder) -> Result; /// Create an index on the provided column(s). @@ -860,7 +890,8 @@ impl Table { /// Delete the rows from table that match the predicate. /// /// # Arguments - /// - `predicate` - The SQL predicate string to filter the rows to be deleted. + /// - `predicate` - A SQL string (`&str`) or DataFusion expression (`&Expr`) + /// that selects the rows to delete. /// /// # Example /// @@ -869,6 +900,7 @@ impl Table { /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch, /// # RecordBatchIterator, Int32Array}; /// # use arrow_schema::{Schema, Field, DataType}; + /// use datafusion_expr::{col, lit}; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let tmpdir = tempfile::tempdir().unwrap(); /// let db = lancedb::connect(tmpdir.path().to_str().unwrap()) @@ -898,11 +930,17 @@ impl Table { /// .execute() /// .await /// .unwrap(); + /// + /// // Using a SQL string: /// tbl.delete("id > 5").await.unwrap(); + /// + /// // Using a DataFusion expression: + /// let expr = col("id").lt(lit(4)); + /// tbl.delete(&expr).await.unwrap(); /// # }); /// ``` - pub async fn delete(&self, predicate: &str) -> Result { - self.inner.delete(predicate).await + pub async fn delete(&self, predicate: impl Into>) -> Result { + self.inner.delete(predicate.into()).await } /// Create an index on the provided column(s). @@ -2777,8 +2815,7 @@ impl BaseTable for NativeTable { } /// Delete rows from the table - async fn delete(&self, predicate: &str) -> Result { - // Delegate to the submodule implementation + async fn delete(&self, predicate: Predicate<'_>) -> Result { delete::execute_delete(self, predicate).await } diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs index 3d469393c..8f11ee019 100644 --- a/rust/lancedb/src/table/delete.rs +++ b/rust/lancedb/src/table/delete.rs @@ -1,9 +1,12 @@ +use std::sync::Arc; + use futures::FutureExt; +use lance::dataset::DeleteBuilder; // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use serde::{Deserialize, Serialize}; -use super::NativeTable; +use super::{NativeTable, Predicate}; use crate::Result; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -21,17 +24,39 @@ pub struct DeleteResult { /// Internal implementation of the delete logic /// /// This logic was moved from NativeTable::delete to keep table.rs clean. -pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result { +pub(crate) async fn execute_delete( + table: &NativeTable, + predicate: Predicate<'_>, +) -> Result { table.dataset.ensure_mutable()?; - let mut dataset = (*table.dataset.get().await?).clone(); - let delete_result = dataset.delete(predicate).boxed().await?; - let num_deleted_rows = delete_result.num_deleted_rows; - let version = dataset.version().version; - table.dataset.update(dataset); - Ok(DeleteResult { - num_deleted_rows, - version, - }) + match predicate { + Predicate::String(s) => { + let mut dataset = (*table.dataset.get().await?).clone(); + let delete_result = dataset.delete(s).boxed().await?; + let num_deleted_rows = delete_result.num_deleted_rows; + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(DeleteResult { + num_deleted_rows, + version, + }) + } + Predicate::Expr(expr) => { + let dataset = table.dataset.get().await?; + let delete_result = DeleteBuilder::from_expr(Arc::clone(&dataset), expr.clone()) + .execute() + .await?; + let num_deleted_rows = delete_result.num_deleted_rows; + let version = delete_result.new_dataset.version().version; + table.dataset.update( + Arc::try_unwrap(delete_result.new_dataset).unwrap_or_else(|arc| (*arc).clone()), + ); + Ok(DeleteResult { + num_deleted_rows, + version, + }) + } + } } #[cfg(test)] @@ -176,4 +201,100 @@ mod tests { "Table version must increment after delete operation" ); } + + #[tokio::test] + async fn test_delete_expr() { + use datafusion_expr::{col, lit}; + + let conn = connect("memory://").execute().await.unwrap(); + + // 1. Create a table with values 0 to 9 + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10))], + ) + .unwrap(); + + let table = conn + .create_table("test_delete_expr", batch) + .execute() + .await + .unwrap(); + + // 2. Verify initial state + assert_eq!(table.count_rows(None).await.unwrap(), 10); + let initial_version = table.version().await.unwrap(); + + // 3. Execute Delete with Expr (removes values > 5) + let expr = col("i").gt(lit(5)); + table.delete(&expr).await.unwrap(); + + // 4. Verify results + assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain + let current_version = table.version().await.unwrap(); + assert!( + current_version > initial_version, + "Table version must increment after delete_expr operation" + ); + + // 5. Verify specific data consistency + let batches = table + .query() + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let batch = &batches[0]; + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Ensure no value > 5 exists + for val in array.iter() { + assert!(val.unwrap() <= 5); + } + } + + #[tokio::test] + async fn test_delete_expr_increments_version() { + use datafusion_expr::lit; + + let conn = connect("memory://").execute().await.unwrap(); + + // Create a table with 5 rows + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); + + let table = conn + .create_table("test_delete_expr_noop", batch) + .execute() + .await + .unwrap(); + + // Capture the initial state (Rows = 5, Version = 1) + let initial_rows = table.count_rows(None).await.unwrap(); + let initial_version = table.version().await.unwrap(); + + assert_eq!(initial_rows, 5); + let expr = lit(false); + table.delete(&expr).await.unwrap(); + + // Rows should still be 5 + let current_rows = table.count_rows(None).await.unwrap(); + assert_eq!( + current_rows, initial_rows, + "Data should not change when predicate is false" + ); + + // version check + let current_version = table.version().await.unwrap(); + assert!( + current_version > initial_version, + "Table version must increment after delete_expr operation" + ); + } }