diff --git a/rust/lancedb/src/data/scannable.rs b/rust/lancedb/src/data/scannable.rs index 35c10be59..a200ac465 100644 --- a/rust/lancedb/src/data/scannable.rs +++ b/rust/lancedb/src/data/scannable.rs @@ -271,15 +271,26 @@ impl Scannable for WithEmbeddingsScannable { .map_err(|e| Error::Runtime { message: format!("Task panicked during embedding computation: {}", e), })??; - // Cast columns to match the declared output schema. The data is - // identical but field metadata (e.g. nested nullability) may - // differ between the embedding function output and the table. - let columns: Vec = result - .columns() + // Look up columns by name (not position) so the result matches + // the output schema even when columns appear in a different + // order — e.g. `add_columns` placed a new column after the + // embedding column, but the computed batch appends embeddings + // at the end. Cast per-column because field metadata (e.g. + // nested nullability) may also differ between the embedding + // function output and the table. + let columns: Vec = output_schema + .fields() .iter() - .enumerate() - .map(|(i, col)| { - let target_type = output_schema.field(i).data_type(); + .map(|field| { + let col = result.column_by_name(field.name()).ok_or_else(|| { + Error::InvalidInput { + message: format!( + "Column '{}' required by the table schema was not present in the input batch", + field.name() + ), + } + })?; + let target_type = field.data_type(); if col.data_type() == target_type { Ok(col.clone()) } else { @@ -964,5 +975,118 @@ mod tests { "Expected EmbeddingFunctionNotFound" ); } + + /// Regression test for https://github.com/lancedb/lancedb/issues/3136. + /// + /// When a column is added to the table after the embedding column via + /// schema evolution, the table schema becomes + /// `[..., embedding, extra]`. The input batch (without the embedding) + /// is `[..., extra]`, and `compute_embeddings_for_batch` appends the + /// embedding at the end giving `[..., extra, embedding]`. A positional + /// cast to the output schema would map `extra` onto `embedding` and + /// fail with a CastError. Columns must be matched by name. + #[tokio::test] + async fn test_with_embeddings_scannable_column_added_after_embedding() { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new("score", DataType::Float64, true), + ])); + let batch = RecordBatch::try_new( + input_schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef, + Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])) as ArrayRef, + ], + ) + .unwrap(); + + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec")); + + // Table schema: embedding column is BEFORE `score`, as would + // happen if `score` was added via `add_columns` after creating + // the table with an embedding on `text`. + let output_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_vec", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + Field::new("score", DataType::Float64, true), + ])); + + let mut scannable = WithEmbeddingsScannable::with_schema( + Box::new(batch), + vec![(embedding_def, mock_embedding)], + output_schema.clone(), + ) + .unwrap(); + + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + + let result_batch = &results[0]; + assert_eq!(result_batch.schema(), output_schema); + assert_eq!(result_batch.num_rows(), 2); + // Position 1 must actually hold the FixedSizeList embedding — + // not the score column reinterpreted by a permissive cast. + let embedding = result_batch + .column(1) + .as_any() + .downcast_ref::() + .expect("position 1 should be a FixedSizeList embedding"); + assert_eq!(embedding.value_length(), 4); + assert_eq!(embedding.null_count(), 0); + } + + /// If the input batch is missing a non-embedding column required by + /// the table schema, we should return a clear error rather than + /// silently producing a malformed batch. + #[tokio::test] + async fn test_with_embeddings_scannable_missing_required_column() { + let input_schema = + Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let batch = RecordBatch::try_new( + input_schema, + vec![Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef], + ) + .unwrap(); + + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec")); + + let output_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_vec", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + Field::new("score", DataType::Float64, true), + ])); + + let mut scannable = WithEmbeddingsScannable::with_schema( + Box::new(batch), + vec![(embedding_def, mock_embedding)], + output_schema, + ) + .unwrap(); + + let stream = scannable.scan_as_stream(); + let results: Result> = stream.try_collect().await; + let err = results.expect_err("expected an error"); + assert!( + matches!(&err, Error::InvalidInput { message } if message.contains("score")), + "expected InvalidInput about missing 'score' column, got: {err:?}" + ); + } } } diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs index 1c4b4bdf3..be8ec28ad 100644 --- a/rust/lancedb/src/table/add_data.rs +++ b/rust/lancedb/src/table/add_data.rs @@ -268,7 +268,9 @@ mod tests { }; use crate::query::{ExecutableQuery, QueryBase, Select}; use crate::table::add_data::NaNVectorBehavior; - use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions}; + use crate::table::{ + ColumnDefinition, ColumnKind, NewColumnTransform, Table, TableDefinition, WriteOptions, + }; use crate::test_utils::TestCustomError; use crate::test_utils::embeddings::MockEmbed; @@ -518,6 +520,225 @@ mod tests { } } + /// Regression test for https://github.com/lancedb/lancedb/issues/3136. + /// + /// When a column is added via `add_columns` AFTER an embedding column, + /// the table schema becomes `[..., embedding, extra]`. Subsequent + /// `table.add()` calls used to fail with a CastError because columns + /// were matched positionally rather than by name. + #[tokio::test] + async fn test_add_with_embeddings_after_add_columns() { + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])); + + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec")); + let table_def = TableDefinition::new( + schema.clone(), + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + let rich_schema = table_def.into_rich_schema(); + + let table = conn + .create_empty_table("embed_evol_test", rich_schema) + .execute() + .await + .unwrap(); + + // Seed a row so add_columns has data to compute against. + let seed_batch = record_batch!(("text", Utf8, ["hello"])).unwrap(); + table.add(seed_batch).execute().await.unwrap(); + + // Add a new physical column AFTER the embedding column. + table + .add_columns( + NewColumnTransform::SqlExpressions(vec![("score".into(), "42.0".into())]), + None, + ) + .await + .unwrap(); + + // Now add data including the new column but WITHOUT the embedding. + // The input batch column order is [text, score]; after computing the + // embedding it becomes [text, score, text_vec], but the table schema + // is [text, text_vec, score]. Columns must be matched by name. + let new_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new("score", DataType::Float64, true), + ])); + let new_batch = RecordBatch::try_new( + new_schema, + vec![ + Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"])), + Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])), + ], + ) + .unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let results: Vec = table + .query() + .select(Select::columns(&["text", "text_vec", "score"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + for batch in &results { + // text_vec must be populated for the newly added rows too. + assert_eq!(batch.column(1).null_count(), 0); + } + } + + /// Like `test_add_with_embeddings_after_add_columns`, but the column + /// added after the embedding is a nested struct rather than a scalar. + /// Verifies that name-based column matching also works when the + /// post-embedding column has a complex Arrow type. + #[tokio::test] + async fn test_add_with_embeddings_after_add_nested_columns() { + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])); + + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec")); + let table_def = TableDefinition::new( + schema, + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + let rich_schema = table_def.into_rich_schema(); + + let table = conn + .create_empty_table("embed_nested_test", rich_schema) + .execute() + .await + .unwrap(); + + let seed_batch = record_batch!(("text", Utf8, ["hello"])).unwrap(); + table.add(seed_batch).execute().await.unwrap(); + + // Add a STRUCT column after the embedding column. + let meta_struct = DataType::Struct( + vec![ + Field::new("source", DataType::Utf8, true), + Field::new("score", DataType::Float64, true), + ] + .into(), + ); + let nested_schema = Arc::new(Schema::new(vec![Field::new( + "meta", + meta_struct.clone(), + true, + )])); + table + .add_columns(NewColumnTransform::AllNulls(nested_schema), None) + .await + .unwrap(); + + // Insert with the nested struct present but the embedding column + // absent. The computed batch is [text, meta, text_vec], but the + // table schema is [text, text_vec, meta] — only name-based matching + // can put `meta` (a struct) in the right slot. + let source = Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"])); + let score = Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])); + let meta = Arc::new(arrow_array::StructArray::from(vec![ + ( + Arc::new(Field::new("source", DataType::Utf8, true)), + source as Arc, + ), + ( + Arc::new(Field::new("score", DataType::Float64, true)), + score as Arc, + ), + ])); + let new_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new("meta", meta_struct, true), + ])); + let new_batch = RecordBatch::try_new( + new_schema, + vec![ + Arc::new(arrow_array::StringArray::from(vec!["foo", "bar"])), + meta, + ], + ) + .unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let results: Vec = table + .query() + .select(Select::columns(&["text", "text_vec", "meta"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + for batch in &results { + assert_eq!(batch.schema().field(2).name(), "meta"); + assert!(matches!( + batch.schema().field(2).data_type(), + DataType::Struct(_) + )); + // text_vec must be populated for the newly added rows too. + assert_eq!(batch.column(1).null_count(), 0); + } + } + #[tokio::test] async fn test_add_casts_to_table_schema() { let table_schema = Arc::new(Schema::new(vec![