mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-22 08:30:40 +00:00
fix: propagate cast errors in add() (#3075)
When we write data with `add()`, we can input data to the table's schema. However, we were using "safe" mode, which propagates errors as nulls. For example, if you pass `u64::max` into a field that is a `u32`, it will just write null instead of giving overflow error. Now it propagates the overflow. This is the same behavior as other systems like DuckDB. --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -326,6 +326,24 @@ def test_add_struct(mem_db: DBConnection):
|
||||
table = mem_db.create_table("test2", schema=schema)
|
||||
table.add(data)
|
||||
|
||||
struct_type = pa.struct(
|
||||
[
|
||||
("b", pa.int64()),
|
||||
("a", pa.int64()),
|
||||
]
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"s_list": [
|
||||
[
|
||||
pa.scalar({"b": 1, "a": 2}, type=struct_type),
|
||||
pa.scalar({"b": 4, "a": None}, type=struct_type),
|
||||
]
|
||||
],
|
||||
}
|
||||
)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
|
||||
def test_add_subschema(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_cast::can_cast_types;
|
||||
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
|
||||
use datafusion::functions::core::{get_field, named_struct};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datafusion_physical_expr::ScalarFunctionExpr;
|
||||
use datafusion_physical_expr::expressions::{Literal, cast};
|
||||
use datafusion_physical_expr::expressions::{CastExpr, Literal};
|
||||
use datafusion_physical_plan::expressions::Column;
|
||||
use datafusion_physical_plan::projection::ProjectionExec;
|
||||
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
|
||||
@@ -25,12 +26,9 @@ pub fn cast_to_table_schema(
|
||||
return Ok(input);
|
||||
}
|
||||
|
||||
let exprs = build_field_exprs(
|
||||
input_schema.fields(),
|
||||
table_schema.fields(),
|
||||
&|idx| Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>,
|
||||
&input_schema,
|
||||
)?;
|
||||
let exprs = build_field_exprs(input_schema.fields(), table_schema.fields(), &|idx| {
|
||||
Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>
|
||||
})?;
|
||||
|
||||
let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = exprs
|
||||
.into_iter()
|
||||
@@ -51,7 +49,6 @@ fn build_field_exprs(
|
||||
input_fields: &Fields,
|
||||
table_fields: &Fields,
|
||||
get_input_expr: &dyn Fn(usize) -> Arc<dyn PhysicalExpr>,
|
||||
input_schema: &Schema,
|
||||
) -> Result<Vec<(Arc<dyn PhysicalExpr>, FieldRef)>> {
|
||||
let config = Arc::new(ConfigOptions::default());
|
||||
let mut result = Vec::new();
|
||||
@@ -72,24 +69,19 @@ fn build_field_exprs(
|
||||
(DataType::Struct(in_children), DataType::Struct(tbl_children))
|
||||
if in_children != tbl_children =>
|
||||
{
|
||||
let sub_exprs = build_field_exprs(
|
||||
in_children,
|
||||
tbl_children,
|
||||
&|child_idx| {
|
||||
let child_name = in_children[child_idx].name();
|
||||
Arc::new(ScalarFunctionExpr::new(
|
||||
&format!("get_field({child_name})"),
|
||||
get_field(),
|
||||
vec![
|
||||
input_expr.clone(),
|
||||
Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
|
||||
],
|
||||
Arc::new(in_children[child_idx].as_ref().clone()),
|
||||
config.clone(),
|
||||
)) as Arc<dyn PhysicalExpr>
|
||||
},
|
||||
input_schema,
|
||||
)?;
|
||||
let sub_exprs = build_field_exprs(in_children, tbl_children, &|child_idx| {
|
||||
let child_name = in_children[child_idx].name();
|
||||
Arc::new(ScalarFunctionExpr::new(
|
||||
&format!("get_field({child_name})"),
|
||||
get_field(),
|
||||
vec![
|
||||
input_expr.clone(),
|
||||
Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
|
||||
],
|
||||
Arc::new(in_children[child_idx].as_ref().clone()),
|
||||
config.clone(),
|
||||
)) as Arc<dyn PhysicalExpr>
|
||||
})?;
|
||||
|
||||
let output_struct_fields: Fields = sub_exprs
|
||||
.iter()
|
||||
@@ -125,17 +117,21 @@ fn build_field_exprs(
|
||||
// Types match: pass through.
|
||||
(inp, tbl) if inp == tbl => input_expr,
|
||||
// Types differ: cast.
|
||||
_ => cast(input_expr, input_schema, table_field.data_type().clone()).map_err(|e| {
|
||||
Error::InvalidInput {
|
||||
// safe: false (the default) means overflow/truncation errors surface at execution time.
|
||||
(_, _) if can_cast_types(input_field.data_type(), table_field.data_type()) => Arc::new(
|
||||
CastExpr::new(input_expr, table_field.data_type().clone(), None),
|
||||
)
|
||||
as Arc<dyn PhysicalExpr>,
|
||||
(inp, tbl) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"cannot cast field '{}' from {} to {}: {}",
|
||||
"cannot cast field '{}' from {} to {}",
|
||||
table_field.name(),
|
||||
input_field.data_type(),
|
||||
table_field.data_type(),
|
||||
e
|
||||
inp,
|
||||
tbl,
|
||||
),
|
||||
}
|
||||
})?,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let output_field = Arc::new(Field::new(
|
||||
@@ -153,10 +149,12 @@ fn build_field_exprs(
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::buffer::OffsetBuffer;
|
||||
use arrow_array::{
|
||||
Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray,
|
||||
Array, Float32Array, Float64Array, Int32Array, Int64Array, ListArray, RecordBatch,
|
||||
StringArray, StructArray, UInt32Array, UInt64Array,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use arrow_schema::{DataType, Field, Fields, Schema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_catalog::MemTable;
|
||||
use futures::TryStreamExt;
|
||||
@@ -495,4 +493,129 @@ mod tests {
|
||||
assert_eq!(b.value(0), "hello");
|
||||
assert_eq!(b.value(1), "world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_narrowing_numeric_cast_success() {
|
||||
let input_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, false)])),
|
||||
vec![Arc::new(UInt64Array::from(vec![1u64, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let table_schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
|
||||
|
||||
let plan = plan_from_batch(input_batch).await;
|
||||
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
|
||||
let result = collect(casted).await;
|
||||
|
||||
assert_eq!(result.schema().field(0).data_type(), &DataType::UInt32);
|
||||
let a: &UInt32Array = result.column(0).as_any().downcast_ref().unwrap();
|
||||
assert_eq!(a.values(), &[1u32, 2, 3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_narrowing_numeric_cast_overflow_errors() {
|
||||
let overflow_val = u32::MAX as u64 + 1;
|
||||
let input_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, false)])),
|
||||
vec![Arc::new(UInt64Array::from(vec![overflow_val]))],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let table_schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
|
||||
|
||||
let plan = plan_from_batch(input_batch).await;
|
||||
// Planning succeeds — the overflow is only detected at execution time.
|
||||
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
|
||||
|
||||
let ctx = SessionContext::new();
|
||||
let stream = casted.execute(0, ctx.task_ctx()).unwrap();
|
||||
let result: Result<Vec<RecordBatch>, _> = stream.try_collect().await;
|
||||
assert!(result.is_err(), "expected overflow error at execution time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_struct_field_reorder() {
|
||||
// list<struct<a: Int32, b: Int32>> → list<struct<b: Int64, a: Int64>>
|
||||
// Tests both reordering (a,b → b,a) and element-type widening (Int32 → Int64).
|
||||
let inner_fields: Fields = vec![
|
||||
Field::new("a", DataType::Int32, true),
|
||||
Field::new("b", DataType::Int32, true),
|
||||
]
|
||||
.into();
|
||||
let struct_array = StructArray::from(vec![
|
||||
(
|
||||
Arc::new(inner_fields[0].as_ref().clone()),
|
||||
Arc::new(Int32Array::from(vec![1, 3])) as _,
|
||||
),
|
||||
(
|
||||
Arc::new(inner_fields[1].as_ref().clone()),
|
||||
Arc::new(Int32Array::from(vec![2, 4])) as _,
|
||||
),
|
||||
]);
|
||||
// Offsets: one list element containing two struct rows (0..2).
|
||||
let offsets = OffsetBuffer::from_lengths(vec![2]);
|
||||
let list_array = ListArray::try_new(
|
||||
Arc::new(Field::new("item", DataType::Struct(inner_fields), true)),
|
||||
offsets,
|
||||
Arc::new(struct_array),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
let input_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new(
|
||||
"s_list",
|
||||
list_array.data_type().clone(),
|
||||
false,
|
||||
)])),
|
||||
vec![Arc::new(list_array)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let table_inner: Fields = vec![
|
||||
Field::new("b", DataType::Int64, true),
|
||||
Field::new("a", DataType::Int64, true),
|
||||
]
|
||||
.into();
|
||||
let table_schema = Schema::new(vec![Field::new(
|
||||
"s_list",
|
||||
DataType::List(Arc::new(Field::new(
|
||||
"item",
|
||||
DataType::Struct(table_inner),
|
||||
true,
|
||||
))),
|
||||
false,
|
||||
)]);
|
||||
|
||||
let plan = plan_from_batch(input_batch).await;
|
||||
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
|
||||
let result = collect(casted).await;
|
||||
|
||||
let list_col = result
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<ListArray>()
|
||||
.unwrap();
|
||||
let struct_col = list_col
|
||||
.values()
|
||||
.as_any()
|
||||
.downcast_ref::<StructArray>()
|
||||
.unwrap();
|
||||
assert_eq!(struct_col.num_columns(), 2);
|
||||
|
||||
let b: &Int64Array = struct_col
|
||||
.column_by_name("b")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap();
|
||||
assert_eq!(b.values(), &[2, 4]);
|
||||
let a: &Int64Array = struct_col
|
||||
.column_by_name("a")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap();
|
||||
assert_eq!(a.values(), &[1, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use futures::FutureExt;
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -23,7 +24,7 @@ pub struct DeleteResult {
|
||||
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
|
||||
table.dataset.ensure_mutable()?;
|
||||
let mut dataset = (*table.dataset.get().await?).clone();
|
||||
let delete_result = dataset.delete(predicate).await?;
|
||||
let delete_result = dataset.delete(predicate).boxed().await?;
|
||||
let num_deleted_rows = delete_result.num_deleted_rows;
|
||||
let version = dataset.version().version;
|
||||
table.dataset.update(dataset);
|
||||
|
||||
Reference in New Issue
Block a user