diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index 39208c0d8..eaa1eb7af 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -326,6 +326,24 @@ def test_add_struct(mem_db: DBConnection):
     table = mem_db.create_table("test2", schema=schema)
     table.add(data)
 
+    struct_type = pa.struct(
+        [
+            ("b", pa.int64()),
+            ("a", pa.int64()),
+        ]
+    )
+    expected = pa.table(
+        {
+            "s_list": [
+                [
+                    pa.scalar({"b": 1, "a": 2}, type=struct_type),
+                    pa.scalar({"b": 4, "a": None}, type=struct_type),
+                ]
+            ],
+        }
+    )
+    assert table.to_arrow() == expected
+
 
 def test_add_subschema(mem_db: DBConnection):
     schema = pa.schema(
diff --git a/rust/lancedb/src/table/datafusion/cast.rs b/rust/lancedb/src/table/datafusion/cast.rs
index 7220de96d..b4abb16c5 100644
--- a/rust/lancedb/src/table/datafusion/cast.rs
+++ b/rust/lancedb/src/table/datafusion/cast.rs
@@ -3,12 +3,13 @@
 use std::sync::Arc;
 
+use arrow_cast::can_cast_types;
 use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
 use datafusion::functions::core::{get_field, named_struct};
 use datafusion_common::ScalarValue;
 use datafusion_common::config::ConfigOptions;
 use datafusion_physical_expr::ScalarFunctionExpr;
-use datafusion_physical_expr::expressions::{Literal, cast};
+use datafusion_physical_expr::expressions::{CastExpr, Literal};
 use datafusion_physical_plan::expressions::Column;
 use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
 
@@ -25,12 +26,9 @@ pub fn cast_to_table_schema(
         return Ok(input);
     }
 
-    let exprs = build_field_exprs(
-        input_schema.fields(),
-        table_schema.fields(),
-        &|idx| Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>,
-        &input_schema,
-    )?;
+    let exprs = build_field_exprs(input_schema.fields(), table_schema.fields(), &|idx| {
+        Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>
+    })?;
 
     let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = exprs
         .into_iter()
@@ -51,7 +49,6 @@ fn build_field_exprs(
     input_fields: &Fields,
     table_fields: &Fields,
     get_input_expr: &dyn Fn(usize) -> Arc<dyn PhysicalExpr>,
-    input_schema: &Schema,
 ) -> Result<Vec<(Arc<dyn PhysicalExpr>, FieldRef)>> {
     let config = Arc::new(ConfigOptions::default());
     let mut result = Vec::new();
@@ -72,24 +69,19 @@ fn build_field_exprs(
             (DataType::Struct(in_children), DataType::Struct(tbl_children))
                 if in_children != tbl_children =>
             {
-                let sub_exprs = build_field_exprs(
-                    in_children,
-                    tbl_children,
-                    &|child_idx| {
-                        let child_name = in_children[child_idx].name();
-                        Arc::new(ScalarFunctionExpr::new(
-                            &format!("get_field({child_name})"),
-                            get_field(),
-                            vec![
-                                input_expr.clone(),
-                                Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
-                            ],
-                            Arc::new(in_children[child_idx].as_ref().clone()),
-                            config.clone(),
-                        )) as Arc<dyn PhysicalExpr>
-                    },
-                    input_schema,
-                )?;
+                let sub_exprs = build_field_exprs(in_children, tbl_children, &|child_idx| {
+                    let child_name = in_children[child_idx].name();
+                    Arc::new(ScalarFunctionExpr::new(
+                        &format!("get_field({child_name})"),
+                        get_field(),
+                        vec![
+                            input_expr.clone(),
+                            Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
+                        ],
+                        Arc::new(in_children[child_idx].as_ref().clone()),
+                        config.clone(),
+                    )) as Arc<dyn PhysicalExpr>
+                })?;
 
                 let output_struct_fields: Fields = sub_exprs
                     .iter()
@@ -125,17 +117,21 @@ fn build_field_exprs(
             // Types match: pass through.
             (inp, tbl) if inp == tbl => input_expr,
             // Types differ: cast.
-            _ => cast(input_expr, input_schema, table_field.data_type().clone()).map_err(|e| {
-                Error::InvalidInput {
+            // safe: false (the default) means overflow/truncation errors surface at execution time.
+            (_, _) if can_cast_types(input_field.data_type(), table_field.data_type()) => Arc::new(
+                CastExpr::new(input_expr, table_field.data_type().clone(), None),
+            )
+                as Arc<dyn PhysicalExpr>,
+            (inp, tbl) => {
+                return Err(Error::InvalidInput {
                     message: format!(
-                        "cannot cast field '{}' from {} to {}: {}",
+                        "cannot cast field '{}' from {} to {}",
                         table_field.name(),
-                        input_field.data_type(),
-                        table_field.data_type(),
-                        e
+                        inp,
+                        tbl,
                     ),
-                }
-            })?,
+                });
+            }
         };
 
         let output_field = Arc::new(Field::new(
@@ -153,10 +149,12 @@
 mod tests {
     use std::sync::Arc;
 
+    use arrow::buffer::OffsetBuffer;
     use arrow_array::{
-        Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray,
+        Array, Float32Array, Float64Array, Int32Array, Int64Array, ListArray, RecordBatch,
+        StringArray, StructArray, UInt32Array, UInt64Array,
     };
-    use arrow_schema::{DataType, Field, Schema};
+    use arrow_schema::{DataType, Field, Fields, Schema};
     use datafusion::prelude::SessionContext;
     use datafusion_catalog::MemTable;
     use futures::TryStreamExt;
@@ -495,4 +493,129 @@ mod tests {
         assert_eq!(b.value(0), "hello");
         assert_eq!(b.value(1), "world");
     }
+
+    #[tokio::test]
+    async fn test_narrowing_numeric_cast_success() {
+        let input_batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, false)])),
+            vec![Arc::new(UInt64Array::from(vec![1u64, 2, 3]))],
+        )
+        .unwrap();
+
+        let table_schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
+
+        let plan = plan_from_batch(input_batch).await;
+        let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+        let result = collect(casted).await;
+
+        assert_eq!(result.schema().field(0).data_type(), &DataType::UInt32);
+        let a: &UInt32Array = result.column(0).as_any().downcast_ref().unwrap();
+        assert_eq!(a.values(), &[1u32, 2, 3]);
+    }
+
+    #[tokio::test]
+    async fn test_narrowing_numeric_cast_overflow_errors() {
+        let overflow_val = u32::MAX as u64 + 1;
+        let input_batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, false)])),
+            vec![Arc::new(UInt64Array::from(vec![overflow_val]))],
+        )
+        .unwrap();
+
+        let table_schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
+
+        let plan = plan_from_batch(input_batch).await;
+        // Planning succeeds — the overflow is only detected at execution time.
+        let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+
+        let ctx = SessionContext::new();
+        let stream = casted.execute(0, ctx.task_ctx()).unwrap();
+        let result: Result<Vec<_>, _> = stream.try_collect().await;
+        assert!(result.is_err(), "expected overflow error at execution time");
+    }
+
+    #[tokio::test]
+    async fn test_list_struct_field_reorder() {
+        // list<struct<a: Int32, b: Int32>> → list<struct<b: Int64, a: Int64>>
+        // Tests both reordering (a,b → b,a) and element-type widening (Int32 → Int64).
+        let inner_fields: Fields = vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]
+        .into();
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(inner_fields[0].as_ref().clone()),
+                Arc::new(Int32Array::from(vec![1, 3])) as _,
+            ),
+            (
+                Arc::new(inner_fields[1].as_ref().clone()),
+                Arc::new(Int32Array::from(vec![2, 4])) as _,
+            ),
+        ]);
+        // Offsets: one list element containing two struct rows (0..2).
+        let offsets = OffsetBuffer::from_lengths(vec![2]);
+        let list_array = ListArray::try_new(
+            Arc::new(Field::new("item", DataType::Struct(inner_fields), true)),
+            offsets,
+            Arc::new(struct_array),
+            None,
+        )
+        .unwrap();
+        let input_batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new(
+                "s_list",
+                list_array.data_type().clone(),
+                false,
+            )])),
+            vec![Arc::new(list_array)],
+        )
+        .unwrap();
+
+        let table_inner: Fields = vec![
+            Field::new("b", DataType::Int64, true),
+            Field::new("a", DataType::Int64, true),
+        ]
+        .into();
+        let table_schema = Schema::new(vec![Field::new(
+            "s_list",
+            DataType::List(Arc::new(Field::new(
+                "item",
+                DataType::Struct(table_inner),
+                true,
+            ))),
+            false,
+        )]);
+
+        let plan = plan_from_batch(input_batch).await;
+        let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+        let result = collect(casted).await;
+
+        let list_col = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        let struct_col = list_col
+            .values()
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(struct_col.num_columns(), 2);
+
+        let b: &Int64Array = struct_col
+            .column_by_name("b")
+            .unwrap()
+            .as_any()
+            .downcast_ref()
+            .unwrap();
+        assert_eq!(b.values(), &[2, 4]);
+        let a: &Int64Array = struct_col
+            .column_by_name("a")
+            .unwrap()
+            .as_any()
+            .downcast_ref()
+            .unwrap();
+        assert_eq!(a.values(), &[1, 3]);
+    }
 }
diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs
index d58263026..3d469393c 100644
--- a/rust/lancedb/src/table/delete.rs
+++ b/rust/lancedb/src/table/delete.rs
@@ -1,3 +1,4 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
+use futures::FutureExt;
 use serde::{Deserialize, Serialize};
@@ -23,7 +24,7 @@ pub struct DeleteResult {
 pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
     table.dataset.ensure_mutable()?;
     let mut dataset = (*table.dataset.get().await?).clone();
-    let delete_result = dataset.delete(predicate).await?;
+    let delete_result = dataset.delete(predicate).boxed().await?;
     let num_deleted_rows = delete_result.num_deleted_rows;
     let version = dataset.version().version;
     table.dataset.update(dataset);