From 0e486511fa45711c78f1c4df011541798daee221 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Mon, 23 Feb 2026 14:43:31 -0800
Subject: [PATCH] feat: hook up new writer for insert (#3029)

This hooks up a new writer implementation for the `add()` method. The main
immediate benefit is that it allows streaming requests to remote tables while
still allowing retries for most inputs.

In NodeJS, we always convert the data to `Vec<RecordBatch>`, so it's always
retry-able. For Python, all inputs are retry-able, except `Iterator` and
`pa.RecordBatchReader`, which can only be consumed once. Some, like
`pa.dataset.Dataset`, are retry-able *and* streaming.

A lot of the changes here are to make the new DataFusion write pipeline
maintain the same behavior as the existing Python-based preprocessing, such
as:

* casting input data to the target schema
* rejecting NaN values if `on_bad_vectors="error"`
* applying embedding functions

In future PRs, we'll enhance these by moving the embedding calls into
DataFusion and making sure we parallelize them. See:
https://github.com/lancedb/lancedb/issues/3048

---------

Co-authored-by: Claude Opus 4.6
---
 nodejs/src/table.rs                         |  11 +
 python/python/lancedb/arrow.py              |  31 +
 python/python/lancedb/scannable.py          | 214 ++++++
 python/python/lancedb/table.py              |  37 +-
 python/python/tests/test_table.py           |  53 +-
 python/src/table.rs                         |   9 +-
 python/src/table/scannable.rs               | 145 ++++
 rust/lancedb/src/arrow.rs                   |   4 +-
 rust/lancedb/src/data/scannable.rs          |  45 +-
 rust/lancedb/src/error.rs                   |  21 +
 rust/lancedb/src/remote/client.rs           |  48 +-
 rust/lancedb/src/remote/retry.rs            | 139 +++-
 rust/lancedb/src/remote/table.rs            | 640 ++++++++++--------
 rust/lancedb/src/remote/table/insert.rs     |  90 ++-
 rust/lancedb/src/table.rs                   |  57 +-
 rust/lancedb/src/table/add_data.rs          | 373 +++++++++-
 rust/lancedb/src/table/datafusion.rs        |   3 +
 rust/lancedb/src/table/datafusion/cast.rs   | 498 ++++++++++++++
 .../src/table/datafusion/reject_nan.rs      | 269 ++++++++
 .../src/table/datafusion/scannable_exec.rs  | 118 ++++
 20 files changed, 2446 insertions(+), 359 deletions(-)
 create mode 100644 python/python/lancedb/scannable.py
 create mode 100644 python/src/table/scannable.rs
 create mode 100644 rust/lancedb/src/table/datafusion/cast.rs
 create mode 100644 rust/lancedb/src/table/datafusion/reject_nan.rs
 create mode 100644 rust/lancedb/src/table/datafusion/scannable_exec.rs

diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index 272cc878c..ed39b9bae 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -71,6 +71,17 @@ impl Table {
     pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result {
         let batches = ipc_file_to_batches(buf.to_vec())
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
+        let batches = batches
+            .into_iter()
+            .map(|batch| {
+                batch.map_err(|e| {
+                    napi::Error::from_reason(format!(
+                        "Failed to read record batch from IPC file: {}",
+                        e
+                    ))
+                })
+            })
+            .collect::<napi::Result<Vec<_>>>()?;
 
         let mut op = self.inner_ref()?.add(batches);
         op = if mode == "append" {
diff --git a/python/python/lancedb/arrow.py b/python/python/lancedb/arrow.py
index ccb62ec55..404efafea 100644
--- a/python/python/lancedb/arrow.py
+++ b/python/python/lancedb/arrow.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
+from functools import singledispatch
 from typing import List, Optional, Tuple, Union
 
+from lancedb.pydantic import LanceModel, model_to_dict
 import pyarrow as pa
 
 from ._lancedb import RecordBatchStream
@@ -80,3 +82,32 @@ def peek_reader(
         yield from reader
 
     return
batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches()) + + +@singledispatch +def to_arrow(data) -> pa.Table: + """Convert a single data object to a pa.Table.""" + raise NotImplementedError(f"to_arrow not implemented for type {type(data)}") + + +@to_arrow.register(pa.RecordBatch) +def _arrow_from_batch(data: pa.RecordBatch) -> pa.Table: + return pa.Table.from_batches([data]) + + +@to_arrow.register(pa.Table) +def _arrow_from_table(data: pa.Table) -> pa.Table: + return data + + +@to_arrow.register(list) +def _arrow_from_list(data: list) -> pa.Table: + if not data: + raise ValueError("Cannot create table from empty list without a schema") + + if isinstance(data[0], LanceModel): + schema = data[0].__class__.to_arrow_schema() + dicts = [model_to_dict(d) for d in data] + return pa.Table.from_pylist(dicts, schema=schema) + + return pa.Table.from_pylist(data) diff --git a/python/python/lancedb/scannable.py b/python/python/lancedb/scannable.py new file mode 100644 index 000000000..beccc8f2e --- /dev/null +++ b/python/python/lancedb/scannable.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from dataclasses import dataclass +from functools import singledispatch +import sys +from typing import Callable, Iterator, Optional +from lancedb.arrow import to_arrow +import pyarrow as pa +import pyarrow.dataset as ds + +from .pydantic import LanceModel + + +@dataclass +class Scannable: + schema: pa.Schema + num_rows: Optional[int] + # Factory function to create a new reader each time (supports re-scanning) + reader: Callable[[], pa.RecordBatchReader] + # Whether reader can be called more than once. For example, an iterator can + # only be consumed once, while a DataFrame can be converted to a new reader + # each time. 
+ rescannable: bool = True + + +@singledispatch +def to_scannable(data) -> Scannable: + # Fallback: try iterable protocol + if hasattr(data, "__iter__"): + return _from_iterable(iter(data)) + raise NotImplementedError(f"to_scannable not implemented for type {type(data)}") + + +@to_scannable.register(pa.RecordBatchReader) +def _from_reader(data: pa.RecordBatchReader) -> Scannable: + # RecordBatchReader can only be consumed once - not rescannable + return Scannable( + schema=data.schema, num_rows=None, reader=lambda: data, rescannable=False + ) + + +@to_scannable.register(pa.RecordBatch) +def _from_batch(data: pa.RecordBatch) -> Scannable: + return Scannable( + schema=data.schema, + num_rows=data.num_rows, + reader=lambda: pa.RecordBatchReader.from_batches(data.schema, [data]), + ) + + +@to_scannable.register(pa.Table) +def _from_table(data: pa.Table) -> Scannable: + return Scannable(schema=data.schema, num_rows=data.num_rows, reader=data.to_reader) + + +@to_scannable.register(ds.Dataset) +def _from_dataset(data: ds.Dataset) -> Scannable: + return Scannable( + schema=data.schema, + num_rows=data.count_rows(), + reader=lambda: data.scanner().to_reader(), + ) + + +@to_scannable.register(ds.Scanner) +def _from_scanner(data: ds.Scanner) -> Scannable: + # Scanner can only be consumed once - not rescannable + return Scannable( + schema=data.projected_schema, + num_rows=None, + reader=data.to_reader, + rescannable=False, + ) + + +@to_scannable.register(list) +def _from_list(data: list) -> Scannable: + if not data: + raise ValueError("Cannot create table from empty list without a schema") + table = to_arrow(data) + return Scannable( + schema=table.schema, num_rows=table.num_rows, reader=table.to_reader + ) + + +@to_scannable.register(dict) +def _from_dict(data: dict) -> Scannable: + raise ValueError("Cannot add a single dictionary to a table. Use a list.") + + +@to_scannable.register(LanceModel) +def _from_lance_model(data: LanceModel) -> Scannable: + raise ValueError("Cannot add a single LanceModel to a table. 
Use a list.") + + +def _from_iterable(data: Iterator) -> Scannable: + first_item = next(data, None) + if first_item is None: + raise ValueError("Cannot create table from empty iterator") + first = to_arrow(first_item) + schema = first.schema + + def iter(): + yield from first.to_batches() + for item in data: + batch = to_arrow(item) + if batch.schema != schema: + try: + batch = batch.cast(schema) + except pa.lib.ArrowInvalid: + raise ValueError( + f"Input iterator yielded a batch with schema that " + f"does not match the schema of other batches.\n" + f"Expected:\n{schema}\nGot:\n{batch.schema}" + ) + yield from batch.to_batches() + + reader = pa.RecordBatchReader.from_batches(schema, iter()) + return to_scannable(reader) + + +_registered_modules: set[str] = set() + + +def _register_optional_converters(): + """Register converters for optional dependencies that are already imported.""" + + if "pandas" in sys.modules and "pandas" not in _registered_modules: + _registered_modules.add("pandas") + import pandas as pd + + @to_arrow.register(pd.DataFrame) + def _arrow_from_pandas(data: pd.DataFrame) -> pa.Table: + table = pa.Table.from_pandas(data, preserve_index=False) + return table.replace_schema_metadata(None) + + @to_scannable.register(pd.DataFrame) + def _from_pandas(data: pd.DataFrame) -> Scannable: + return to_scannable(_arrow_from_pandas(data)) + + if "polars" in sys.modules and "polars" not in _registered_modules: + _registered_modules.add("polars") + import polars as pl + + @to_arrow.register(pl.DataFrame) + def _arrow_from_polars(data: pl.DataFrame) -> pa.Table: + return data.to_arrow() + + @to_scannable.register(pl.DataFrame) + def _from_polars(data: pl.DataFrame) -> Scannable: + arrow = data.to_arrow() + return Scannable( + schema=arrow.schema, num_rows=len(data), reader=arrow.to_reader + ) + + @to_scannable.register(pl.LazyFrame) + def _from_polars_lazy(data: pl.LazyFrame) -> Scannable: + arrow = data.collect().to_arrow() + return Scannable( + schema=arrow.schema, num_rows=arrow.num_rows, reader=arrow.to_reader + ) + + if "datasets" in sys.modules and "datasets" not in _registered_modules: + _registered_modules.add("datasets") + from datasets import Dataset as HFDataset + from datasets import DatasetDict as HFDatasetDict + + @to_scannable.register(HFDataset) + def _from_hf_dataset(data: HFDataset) -> Scannable: + table = data.data.table # Access underlying Arrow table + return Scannable( + schema=table.schema, num_rows=len(data), reader=table.to_reader + ) + + @to_scannable.register(HFDatasetDict) + def _from_hf_dataset_dict(data: HFDatasetDict) -> Scannable: + # HuggingFace DatasetDict: combine all splits with a 'split' column + schema = data[list(data.keys())[0]].features.arrow_schema + if "split" not in schema.names: + schema = schema.append(pa.field("split", pa.string())) + + def gen(): + for split_name, dataset in data.items(): + for batch in dataset.data.to_batches(): + split_arr = pa.array( + [split_name] * len(batch), type=pa.string() + ) + yield pa.RecordBatch.from_arrays( + list(batch.columns) + [split_arr], schema=schema + ) + + total_rows = sum(len(dataset) for dataset in data.values()) + return Scannable( + schema=schema, + num_rows=total_rows, + reader=lambda: pa.RecordBatchReader.from_batches(schema, gen()), + ) + + if "lance" in sys.modules and "lance" not in _registered_modules: + _registered_modules.add("lance") + import lance + + @to_scannable.register(lance.LanceDataset) + def _from_lance(data: lance.LanceDataset) -> Scannable: + return Scannable( + 
schema=data.schema,
+            num_rows=data.count_rows(),
+            reader=lambda: data.scanner().to_reader(),
+        )
+
+
+# Register on module load
+_register_optional_converters()
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index b006586a8..029d234f9 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -25,6 +25,8 @@ from typing import (
 )
 from urllib.parse import urlparse
 
+from lancedb.scannable import _register_optional_converters, to_scannable
+
 from . import __version__
 from lancedb.arrow import peek_reader
 from lancedb.background_loop import LOOP
@@ -3727,18 +3729,31 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data = _sanitize_data(
-            data,
-            schema,
-            metadata=schema.metadata,
-            on_bad_vectors=on_bad_vectors,
-            fill_value=fill_value,
-            allow_subschema=True,
-        )
-        if isinstance(data, pa.Table):
-            data = data.to_reader()
-        return await self._inner.add(data, mode or "append")
+        # _sanitize_data is an old code path, but we will use it until the
+        # new code path is ready.
+        if on_bad_vectors != "error" or (
+            schema.metadata is not None and b"embedding_functions" in schema.metadata
+        ):
+            data = _sanitize_data(
+                data,
+                schema,
+                metadata=schema.metadata,
+                on_bad_vectors=on_bad_vectors,
+                fill_value=fill_value,
+                allow_subschema=True,
+            )
+        _register_optional_converters()
+        data = to_scannable(data)
+        try:
+            return await self._inner.add(data, mode or "append")
+        except RuntimeError as e:
+            if "Cast error" in str(e):
+                raise ValueError(e)
+            elif "Vector column contains NaN" in str(e):
+                raise ValueError(e)
+            else:
+                raise
 
     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index eecfd40e6..39208c0d8 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -810,7 +810,7 @@ def test_create_index_name_and_train_parameters(
     )
 
 
-def test_add_with_nans(mem_db: DBConnection):
+def test_create_with_nans(mem_db: DBConnection):
     # by default we raise an error on bad input vectors
     bad_data = [
         {"vector": [np.nan], "item": "bar", "price": 20.0},
@@ -854,6 +854,57 @@ def test_add_with_nans(mem_db: DBConnection):
     assert np.allclose(v, np.array([0.0, 0.0]))
 
 
+def test_add_with_nans(mem_db: DBConnection):
+    schema = pa.schema(
+        [
+            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
+            pa.field("item", pa.string(), nullable=True),
+            pa.field("price", pa.float64(), nullable=False),
+        ],
+    )
+    table = mem_db.create_table("test", schema=schema)
+    # by default we raise an error on bad input vectors
+    bad_data = [
+        {"vector": [np.nan], "item": "bar", "price": 20.0},
+        {"vector": [5], "item": "bar", "price": 20.0},
+        {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
+        {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
+    ]
+    for row in bad_data:
+        with pytest.raises(ValueError):
+            table.add(
+                data=[row],
+            )
+
+    table.add(
+        [
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
+            {"vector": [np.nan], "item": "bar", "price": 20.0},
+            {"vector": [5], "item": "bar", "price": 20.0},
+            {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
+        ],
+        on_bad_vectors="drop",
+    )
+    assert len(table) == 2
+    table.delete("true")
+
+    # We can fill bad input with some value
+    table.add(
+        data=[
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [np.nan], "item": "bar", "price": 20.0},
+            {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
+        ],
+        on_bad_vectors="fill",
+        fill_value=0.0,
+    )
+    assert len(table) == 3
+    arrow_tbl = table.search().where("item == 'bar'").to_arrow()
+    v = arrow_tbl["vector"].to_pylist()[0]
+    assert np.allclose(v, np.array([0.0, 0.0]))
+
+
 def test_restore(mem_db: DBConnection):
     table = mem_db.create_table(
         "my_table",
diff --git a/python/src/table.rs b/python/src/table.rs
index 353b22ff0..31279847d 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -7,6 +7,7 @@ use crate::{
     error::PythonErrorExt,
     index::{extract_index_params, IndexConfig},
     query::{Query, TakeQuery},
+    table::scannable::PyScannable,
 };
 use arrow::{
     datatypes::{DataType, Schema},
@@ -25,6 +26,8 @@ use pyo3::{
 };
 use pyo3_async_runtimes::tokio::future_into_py;
 
+mod scannable;
+
 /// Statistics about a compaction operation.
 #[pyclass(get_all)]
 #[derive(Clone, Debug)]
@@ -293,12 +296,10 @@ impl Table {
 
     pub fn add<'a>(
         self_: PyRef<'a, Self>,
-        data: Bound<'_, PyAny>,
+        data: PyScannable,
         mode: String,
     ) -> PyResult<Bound<'a, PyAny>> {
-        let batches: Box<dyn RecordBatchReader + Send> =
-            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?);
-        let mut op = self_.inner_ref()?.add(batches);
+        let mut op = self_.inner_ref()?.add(data);
         if mode == "append" {
             op = op.mode(AddDataMode::Append);
         } else if mode == "overwrite" {
diff --git a/python/src/table/scannable.rs b/python/src/table/scannable.rs
new file mode 100644
index 000000000..e2bfc1295
--- /dev/null
+++ b/python/src/table/scannable.rs
@@ -0,0 +1,145 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::Arc;
+
+use arrow::{
+    datatypes::{Schema, SchemaRef},
+    ffi_stream::ArrowArrayStreamReader,
+    pyarrow::{FromPyArrow, PyArrowType},
+};
+use futures::StreamExt;
+use lancedb::{
+    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+    data::scannable::Scannable,
+    Error,
+};
+use pyo3::{types::PyAnyMethods, FromPyObject, Py, PyAny, Python};
+
+/// Adapter that implements Scannable for a Python reader factory callable.
+///
+/// This holds a Python callable that returns a RecordBatchReader when called.
+/// For rescannable sources, the callable can be invoked multiple times to
+/// get fresh readers.
+pub struct PyScannable {
+    /// Python callable that returns a RecordBatchReader
+    reader_factory: Py<PyAny>,
+    schema: SchemaRef,
+    num_rows: Option<usize>,
+    rescannable: bool,
+}
+
+impl std::fmt::Debug for PyScannable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PyScannable")
+            .field("schema", &self.schema)
+            .field("num_rows", &self.num_rows)
+            .field("rescannable", &self.rescannable)
+            .finish()
+    }
+}
+
+impl Scannable for PyScannable {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        let reader: Result<ArrowArrayStreamReader, Error> = {
+            Python::attach(|py| {
+                let result =
+                    self.reader_factory
+                        .call0(py)
+                        .map_err(|e| lancedb::Error::Runtime {
+                            message: format!("Python reader factory failed: {}", e),
+                        })?;
+                ArrowArrayStreamReader::from_pyarrow_bound(result.bind(py)).map_err(|e| {
+                    lancedb::Error::Runtime {
+                        message: format!("Failed to create Arrow reader from Python: {}", e),
+                    }
+                })
+            })
+        };
+
+        // Reader is blocking but stream is non-blocking, so we need to spawn a task to pull.
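+        // A bounded channel of size 1 gives natural backpressure: the
+        // blocking reader thread pauses until the async consumer has taken
+        // the previous batch, so at most one batch is buffered at a time.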
+        let (tx, rx) = tokio::sync::mpsc::channel(1);
+
+        let join_handle = tokio::task::spawn_blocking(move || {
+            let reader = match reader {
+                Ok(reader) => reader,
+                Err(e) => {
+                    let _ = tx.blocking_send(Err(e));
+                    return;
+                }
+            };
+            for batch in reader {
+                match batch {
+                    Ok(batch) => {
+                        if tx.blocking_send(Ok(batch)).is_err() {
+                            // Receiver dropped, stop processing
+                            break;
+                        }
+                    }
+                    Err(source) => {
+                        let _ = tx.blocking_send(Err(Error::Arrow { source }));
+                        break;
+                    }
+                }
+            }
+        });
+
+        let schema = self.schema.clone();
+        let stream = futures::stream::unfold(
+            (rx, Some(join_handle)),
+            |(mut rx, join_handle)| async move {
+                match rx.recv().await {
+                    Some(Ok(batch)) => Some((Ok(batch), (rx, join_handle))),
+                    Some(Err(e)) => Some((Err(e), (rx, join_handle))),
+                    None => {
+                        // Channel closed. Check if the task panicked: a panic
+                        // drops the sender without sending an error, so without
+                        // this check we'd silently return a truncated stream.
+                        if let Some(handle) = join_handle {
+                            if let Err(join_err) = handle.await {
+                                return Some((
+                                    Err(Error::Runtime {
+                                        message: format!("Reader task panicked: {}", join_err),
+                                    }),
+                                    (rx, None),
+                                ));
+                            }
+                        }
+                        None
+                    }
+                }
+            },
+        );
+        Box::pin(SimpleRecordBatchStream::new(stream.fuse(), schema))
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        self.num_rows
+    }
+
+    fn rescannable(&self) -> bool {
+        self.rescannable
+    }
+}
+
+impl<'py> FromPyObject<'py> for PyScannable {
+    fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
+        // Convert from Scannable dataclass.
+        let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
+        let schema = Arc::new(schema.0);
+        let num_rows: Option<usize> = ob.getattr("num_rows")?.extract()?;
+        let rescannable: bool = ob.getattr("rescannable")?.extract()?;
+        let reader_factory: Py<PyAny> = ob.getattr("reader")?.unbind();
+
+        Ok(Self {
+            schema,
+            reader_factory,
+            num_rows,
+            rescannable,
+        })
+    }
+}
diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs
index df6e51a34..a8c710865 100644
--- a/rust/lancedb/src/arrow.rs
+++ b/rust/lancedb/src/arrow.rs
@@ -155,9 +155,7 @@ impl IntoArrowStream for SendableRecordBatchStream {
 impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
     fn into_arrow(self) -> Result<SendableRecordBatchStream> {
         let schema = self.schema();
-        let stream = self.map_err(|df_err| Error::Runtime {
-            message: df_err.to_string(),
-        });
+        let stream = self.map_err(|df_err| df_err.into());
         Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
     }
 }
diff --git a/rust/lancedb/src/data/scannable.rs b/rust/lancedb/src/data/scannable.rs
index 350742bd7..8248a3202 100644
--- a/rust/lancedb/src/data/scannable.rs
+++ b/rust/lancedb/src/data/scannable.rs
@@ -9,7 +9,7 @@
 
 use std::sync::Arc;
 
-use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{ArrowError, SchemaRef};
 use async_trait::async_trait;
 use futures::stream::once;
@@ -228,6 +228,19 @@ impl WithEmbeddingsScannable {
         let table_definition = TableDefinition::new(output_schema, column_definitions);
         let output_schema = table_definition.into_rich_schema();
 
+        Self::with_schema(inner, embeddings, output_schema)
+    }
+
+    /// Create a WithEmbeddingsScannable with a specific output schema.
+    ///
+    /// Use this when the table schema is already known (e.g. during add) to
+    /// avoid nullability mismatches between the embedding function's declared
+    /// type and the table's stored type.
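+    ///
+    /// Note: the schema is used as-is; this constructor does not verify that
+    /// `output_schema` actually contains the embedding destination columns.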
+    pub fn with_schema(
+        inner: Box<dyn Scannable>,
+        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
+        output_schema: SchemaRef,
+    ) -> Result<Self> {
         Ok(Self {
             inner,
             embeddings,
@@ -245,9 +258,11 @@ impl Scannable for WithEmbeddingsScannable {
         let inner_stream = self.inner.scan_as_stream();
         let embeddings = self.embeddings.clone();
         let output_schema = self.output_schema.clone();
+        let stream_schema = output_schema.clone();
 
         let mapped_stream = inner_stream.then(move |batch_result| {
             let embeddings = embeddings.clone();
+            let output_schema = output_schema.clone();
             async move {
                 let batch = batch_result?;
                 let result = tokio::task::spawn_blocking(move || {
@@ -257,12 +272,29 @@ impl Scannable for WithEmbeddingsScannable {
                 .map_err(|e| Error::Runtime {
                     message: format!("Task panicked during embedding computation: {}", e),
                 })??;
+                // Cast columns to match the declared output schema. The data is
+                // identical but field metadata (e.g. nested nullability) may
+                // differ between the embedding function output and the table.
+                let columns: Vec<ArrayRef> = result
+                    .columns()
+                    .iter()
+                    .enumerate()
+                    .map(|(i, col)| {
+                        let target_type = output_schema.field(i).data_type();
+                        if col.data_type() == target_type {
+                            Ok(col.clone())
+                        } else {
+                            arrow_cast::cast(col, target_type).map_err(Error::from)
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let result = RecordBatch::try_new(output_schema, columns)?;
                 Ok(result)
             }
         });
 
         Box::pin(SimpleRecordBatchStream {
-            schema: output_schema,
+            schema: stream_schema,
             stream: mapped_stream,
         })
     }
@@ -303,8 +335,13 @@ pub fn scannable_with_embeddings(
     }
 
     if !embeddings.is_empty() {
-        return Ok(Box::new(WithEmbeddingsScannable::try_new(
-            inner, embeddings,
+        // Use the table's schema so embedding column types (including nested
+        // nullability) match what's stored, avoiding mismatches with the
+        // embedding function's declared dest_type.
+        return Ok(Box::new(WithEmbeddingsScannable::with_schema(
+            inner,
+            embeddings,
+            table_definition.schema.clone(),
         )?));
     }
 }
diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs
index 55e2350ac..d04566d7d 100644
--- a/rust/lancedb/src/error.rs
+++ b/rust/lancedb/src/error.rs
@@ -4,6 +4,7 @@
 use std::sync::PoisonError;
 
 use arrow_schema::ArrowError;
+use datafusion_common::DataFusionError;
 use snafu::Snafu;
 
 pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
@@ -105,6 +106,26 @@ impl From<ArrowError> for Error {
     }
 }
 
+impl From<DataFusionError> for Error {
+    fn from(source: DataFusionError) -> Self {
+        match source {
+            DataFusionError::ArrowError(source, _) => (*source).into(),
+            DataFusionError::External(source) => match source.downcast::<Self>() {
+                Ok(e) => *e,
+                Err(source) => match source.downcast::<ArrowError>() {
+                    Ok(arrow_error) => Self::Arrow {
+                        source: *arrow_error,
+                    },
+                    Err(source) => Self::External { source },
+                },
+            },
+            other => Self::External {
+                source: Box::new(other),
+            },
+        }
+    }
+}
+
 impl From<lance::Error> for Error {
     fn from(source: lance::Error) -> Self {
         // Try to unwrap external errors that were wrapped by lance
diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs
index 73d2f14da..de71727d9 100644
--- a/rust/lancedb/src/remote/client.rs
+++ b/rust/lancedb/src/remote/client.rs
@@ -724,12 +724,58 @@ pub mod test_utils {
         }
     }
 
+    /// Consume a reqwest body into bytes, returning an error if the body
+    /// stream fails. This is used by MockSender to materialize streaming
+    /// bodies so that data pipeline errors (e.g. NaN rejection) are triggered
+    /// during mock sends just as they would be during a real HTTP upload.
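+    ///
+    /// Errors are returned as plain strings since this is test-only code;
+    /// `MockSender` converts them into a synthetic 500 response.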
+    pub async fn try_collect_body(body: reqwest::Body) -> std::result::Result<Vec<u8>, String> {
+        use http_body::Body;
+        use std::pin::Pin;
+
+        let mut body = body;
+        let mut data = Vec::new();
+        let mut body_pin = Pin::new(&mut body);
+        while let Some(frame) = futures::StreamExt::next(&mut futures::stream::poll_fn(|cx| {
+            body_pin.as_mut().poll_frame(cx)
+        }))
+        .await
+        {
+            match frame {
+                Ok(frame) => {
+                    if let Some(bytes) = frame.data_ref() {
+                        data.extend_from_slice(bytes);
+                    }
+                }
+                Err(e) => return Err(e.to_string()),
+            }
+        }
+        Ok(data)
+    }
+
     impl HttpSend for MockSender {
         async fn send(
             &self,
             _client: &reqwest::Client,
-            request: reqwest::Request,
+            mut request: reqwest::Request,
         ) -> reqwest::Result<reqwest::Response> {
+            // Consume any streaming body to materialize it into bytes.
+            // This triggers data pipeline errors (e.g. NaN rejection) that
+            // would otherwise only fire when a real HTTP client reads the body.
+            if let Some(body) = request.body_mut().take() {
+                match try_collect_body(body).await {
+                    Ok(bytes) => {
+                        *request.body_mut() = Some(reqwest::Body::from(bytes));
+                    }
+                    Err(msg) => {
+                        // Simulate a failed request by returning a 500 response.
+                        return Ok(http::Response::builder()
+                            .status(500)
+                            .body(msg)
+                            .unwrap()
+                            .into());
+                    }
+                }
+            }
             let response = (self.f)(request);
             Ok(response)
         }
diff --git a/rust/lancedb/src/remote/retry.rs b/rust/lancedb/src/remote/retry.rs
index a20d91669..8f34195f6 100644
--- a/rust/lancedb/src/remote/retry.rs
+++ b/rust/lancedb/src/remote/retry.rs
@@ -60,6 +60,34 @@ impl<'a> RetryCounter<'a> {
         self.check_out_of_retries(Box::new(source), status_code)
     }
 
+    /// Increment the appropriate failure counter based on the error type.
+    ///
+    /// For `Error::Http` whose source is a connect error, increments
+    /// `connect_failures`. For read errors (`is_body` or `is_decode`),
+    /// increments `read_failures`. For all other errors, increments
+    /// `request_failures`. Calls `check_out_of_retries` to enforce global limits.
+    pub fn increment_from_error(&mut self, source: crate::Error) -> crate::Result<()> {
+        let reqwest_err = match &source {
+            crate::Error::Http { source, .. } => source.downcast_ref::<reqwest::Error>(),
+            _ => None,
+        };
+
+        if reqwest_err.is_some_and(|e| e.is_connect()) {
+            self.connect_failures += 1;
+        } else if reqwest_err.is_some_and(|e| e.is_body() || e.is_decode()) {
+            self.read_failures += 1;
+        } else {
+            self.request_failures += 1;
+        }
+
+        let status_code = if let crate::Error::Http { status_code, .. } = &source {
+            *status_code
+        } else {
+            None
+        };
+        self.check_out_of_retries(Box::new(source), status_code)
+    }
+
     pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
         self.connect_failures += 1;
         let status_code = source.status();
@@ -77,7 +105,7 @@ impl<'a> RetryCounter<'a> {
         let jitter = rand::random::<f32>() * self.config.backoff_jitter;
         let sleep_time = Duration::from_secs_f32(backoff + jitter);
         debug!(
-            "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
+            "Retrying request {:?} ({}/{} connect, {}/{} request, {}/{} read) in {:?}",
             self.request_id,
             self.connect_failures,
             self.config.connect_retries,
@@ -91,6 +119,115 @@ impl<'a> RetryCounter<'a> {
     }
 }
 
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_config() -> ResolvedRetryConfig {
+        ResolvedRetryConfig {
+            retries: 3,
+            connect_retries: 2,
+            read_retries: 3,
+            backoff_factor: 0.0,
+            backoff_jitter: 0.0,
+            statuses: vec![reqwest::StatusCode::BAD_GATEWAY],
+        }
+    }
+
+    /// Get a real reqwest connect error by trying to connect to a refused port.
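+    ///
+    /// `reqwest::Error` has no public constructor, so the test provokes a
+    /// genuine connection failure to obtain one.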
+ async fn make_connect_error() -> reqwest::Error { + // Port 1 is almost always refused/unavailable. + reqwest::Client::new() + .get("http://127.0.0.1:1") + .send() + .await + .unwrap_err() + } + + #[tokio::test] + async fn test_increment_from_error_connect() { + let config = test_config(); + let mut counter = RetryCounter::new(&config, "test".to_string()); + + let connect_err = make_connect_error().await; + assert!(connect_err.is_connect()); + + let http_err = crate::Error::Http { + source: Box::new(connect_err), + request_id: "test".to_string(), + status_code: None, + }; + + // First connect failure: should be ok (1 < 2) + counter.increment_from_error(http_err).unwrap(); + assert_eq!(counter.connect_failures, 1); + assert_eq!(counter.request_failures, 0); + + // Second connect failure: should hit the limit (2 >= 2) + let connect_err2 = make_connect_error().await; + let http_err2 = crate::Error::Http { + source: Box::new(connect_err2), + request_id: "test".to_string(), + status_code: None, + }; + let result = counter.increment_from_error(http_err2); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::Error::Retry { + connect_failures: 2, + max_connect_failures: 2, + .. + } + )); + } + + #[test] + fn test_increment_from_error_request() { + let config = test_config(); + let mut counter = RetryCounter::new(&config, "test".to_string()); + + let http_err = crate::Error::Http { + source: "bad gateway".into(), + request_id: "test".to_string(), + status_code: Some(reqwest::StatusCode::BAD_GATEWAY), + }; + + counter.increment_from_error(http_err).unwrap(); + assert_eq!(counter.request_failures, 1); + assert_eq!(counter.connect_failures, 0); + } + + #[tokio::test] + async fn test_increment_from_error_respects_global_limits() { + // If request_failures is already at max, a connect error should still + // trigger the global limit check. + let config = test_config(); + let mut counter = RetryCounter::new(&config, "test".to_string()); + counter.request_failures = 3; // at max + + let connect_err = make_connect_error().await; + let http_err = crate::Error::Http { + source: Box::new(connect_err), + request_id: "test".to_string(), + status_code: None, + }; + + // Even though connect_failures would be 1 (under limit of 2), + // request_failures is already at 3 (>= limit of 3), so this should fail. + let result = counter.increment_from_error(http_err); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::Error::Retry { + request_failures: 3, + connect_failures: 1, + .. 
+ } + )); + } +} + #[derive(Debug, Clone)] pub struct ResolvedRetryConfig { pub retries: u8, diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index fcfca3b5a..e2096f856 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -3,17 +3,16 @@ pub mod insert; +use self::insert::RemoteInsertExec; + use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; -use super::util::stream_as_body; use super::ARROW_STREAM_CONTENT_TYPE; -use crate::data::scannable::Scannable; use crate::index::waiter::wait_for_index; use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; -use crate::remote::util::stream_as_ipc; use crate::table::query::create_multi_vector_plan; use crate::table::AddColumnsResult; use crate::table::AddResult; @@ -23,7 +22,7 @@ use crate::table::DropColumnsResult; use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; -use crate::table::{AddDataMode, AnyQuery, Filter, TableStatistics}; +use crate::table::{AnyQuery, Filter, TableStatistics}; use crate::utils::background_cache::BackgroundCache; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{ @@ -358,110 +357,6 @@ impl RemoteTable { Ok(res) } - /// Send a request with data from a Scannable source. - /// - /// For rescannable sources, this will retry on retryable errors by re-reading - /// the data. For non-rescannable sources (streams), only a single attempt is made. - async fn send_scannable( - &self, - req_builder: RequestBuilder, - data: &mut dyn Scannable, - ) -> Result<(String, Response)> { - use crate::remote::retry::RetryCounter; - - // Right now, Python and Typescript don't pass down re-scannable data yet. - // So to preserve existing retry behavior, we have to collect data in - // memory for now. Once they expose rescannable data sources, we can remove this. - if !data.rescannable() && self.client.retry_config.retries > 0 { - let mut body = Vec::new(); - stream_as_ipc(data.scan_as_stream())? 
- .try_for_each(|b| { - body.extend_from_slice(&b); - futures::future::ok(()) - }) - .await?; - let req_builder = req_builder.body(body); - return self.client.send_with_retry(req_builder, None, true).await; - } - - let rescannable = data.rescannable(); - let max_retries = if rescannable { - self.client.retry_config.retries - } else { - 0 - }; - - // Clone the request builder to extract the request id - let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime { - message: "Attempted to retry a request that cannot be cloned".to_string(), - })?; - let (_, r) = tmp_req.build_split(); - let mut r = r.map_err(|e| Error::Runtime { - message: format!("Failed to build request: {}", e), - })?; - let request_id = self.client.extract_request_id(&mut r); - let mut retry_counter = RetryCounter::new(&self.client.retry_config, request_id.clone()); - - loop { - // Re-read data on each attempt - let stream = data.scan_as_stream(); - let body = stream_as_body(stream)?; - - let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime { - message: "Attempted to retry a request that cannot be cloned".to_string(), - })?; - req_builder = req_builder.body(body); - - let (c, request) = req_builder.build_split(); - let mut request = request.map_err(|e| Error::Runtime { - message: format!("Failed to build request: {}", e), - })?; - self.client.set_request_id(&mut request, &request_id); - - // Apply dynamic headers - request = self.client.apply_dynamic_headers(request).await?; - - self.client.log_request(&request, &request_id); - - let response = match self.client.sender.send(&c, request).await { - Ok(r) => r, - Err(err) => { - if err.is_connect() { - retry_counter.increment_connect_failures(err)?; - } else if err.is_body() || err.is_decode() { - retry_counter.increment_read_failures(err)?; - } else { - return Err(crate::Error::Http { - source: err.into(), - request_id, - status_code: None, - }); - } - tokio::time::sleep(retry_counter.next_sleep_time()).await; - continue; - } - }; - - let status = response.status(); - - // Check for retryable status codes - if self.client.retry_config.statuses.contains(&status) - && retry_counter.request_failures < max_retries - { - let http_err = crate::Error::Http { - source: format!("Retryable status code: {}", status).into(), - request_id: request_id.clone(), - status_code: Some(status), - }; - retry_counter.increment_request_failures(http_err)?; - tokio::time::sleep(retry_counter.next_sleep_time()).await; - continue; - } - - return Ok((request_id, response)); - } - } - pub(super) async fn handle_table_not_found( table_name: &str, response: reqwest::Response, @@ -1077,39 +972,75 @@ impl BaseTable for RemoteTable { status_code: None, }) } - async fn add(&self, mut add: AddDataBuilder) -> Result { - self.check_mutable().await?; - let mut request = self - .client - .post(&format!("/v1/table/{}/insert/", self.identifier)) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); + async fn add(&self, add: AddDataBuilder) -> Result { + use crate::remote::retry::RetryCounter; - match add.mode { - AddDataMode::Append => {} - AddDataMode::Overwrite => { - request = request.query(&[("mode", "overwrite")]); + self.check_mutable().await?; + + let table_schema = self.schema().await?; + let table_def = TableDefinition::try_from_rich_schema(table_schema.clone())?; + let output = add.into_plan(&table_schema, &table_def)?; + + let mut insert: Arc = Arc::new(RemoteInsertExec::new( + self.name.clone(), + self.identifier.clone(), + self.client.clone(), + output.plan, + 
output.overwrite,
+        ));
+
+        let mut retry_counter =
+            RetryCounter::new(&self.client.retry_config, uuid::Uuid::new_v4().to_string());
+
+        loop {
+            let stream = execute_plan(insert.clone(), Default::default())?;
+            let result: Result<Vec<_>> = stream.try_collect().await.map_err(Error::from);
+
+            match result {
+                Ok(_) => {
+                    let add_result = insert
+                        .as_any()
+                        .downcast_ref::<RemoteInsertExec<S>>()
+                        .and_then(|i| i.add_result())
+                        .unwrap_or(AddResult { version: 0 });
+
+                    if output.overwrite {
+                        self.invalidate_schema_cache();
+                    }
+
+                    return Ok(add_result);
+                }
+                Err(err) if output.rescannable => {
+                    let retryable = match &err {
+                        Error::Http {
+                            source,
+                            status_code,
+                            ..
+                        } => {
+                            // Don't retry read errors (is_body/is_decode): the
+                            // server may have committed the write already, and
+                            // without an idempotency key we'd duplicate data.
+                            source
+                                .downcast_ref::<reqwest::Error>()
+                                .is_some_and(|e| e.is_connect())
+                                || status_code
+                                    .is_some_and(|s| self.client.retry_config.statuses.contains(&s))
+                        }
+                        _ => false,
+                    };
+
+                    if retryable {
+                        retry_counter.increment_from_error(err)?;
+                        tokio::time::sleep(retry_counter.next_sleep_time()).await;
+                        insert = insert.reset_state()?;
+                        continue;
+                    }
+
+                    return Err(err);
+                }
+                Err(err) => return Err(err),
+            }
+        }
-            }
-        }
-
-        let (request_id, response) = self.send_scannable(request, &mut *add.data).await?;
-        let response = self.check_table_response(&request_id, response).await?;
-        let body = response.text().await.err_to_http(request_id.clone())?;
-        if body.trim().is_empty() {
-            // Backward compatible with old servers
-            return Ok(AddResult { version: 0 });
-        }
-
-        let add_response: AddResult = serde_json::from_str(&body).map_err(|e| Error::Http {
-            source: format!("Failed to parse add response: {}", e).into(),
-            request_id,
-            status_code: None,
-        })?;
-
-        if matches!(add.mode, AddDataMode::Overwrite) {
-            self.invalidate_schema_cache();
-        }
-
-        Ok(add_response)
     }
 
     async fn create_plan(
@@ -1756,9 +1687,8 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
     }
 
     async fn table_definition(&self) -> Result<TableDefinition> {
-        Err(Error::NotSupported {
-            message: "table_definition is not supported on LanceDB cloud.".into(),
-        })
+        let schema = self.schema().await?;
+        TableDefinition::try_from_rich_schema(schema)
     }
 
     async fn uri(&self) -> Result<String> {
         // Check if we already have the location cached
@@ -1883,6 +1813,8 @@ mod tests {
 
     use super::*;
 
+    use crate::table::AddDataMode;
+
     use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
@@ -2095,6 +2027,16 @@ mod tests {
         body
     }
 
+    /// Build a JSON describe response for the given schema.
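+    ///
+    /// The new add() path fetches the table schema before building its write
+    /// plan, so mock handlers must answer /describe/ in addition to /insert/.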
+ fn describe_response(schema: &Schema) -> String { + let json_schema = JsonSchema::try_from(schema).unwrap(); + serde_json::to_string(&json!({ + "version": 1, + "schema": json_schema, + })) + .unwrap() + } + #[rstest] #[case("", 0)] #[case("{}", 0)] @@ -2111,30 +2053,35 @@ mod tests { // Clone response_body to give it 'static lifetime for the closure let response_body = response_body.to_string(); + let describe_body = describe_response(&data.schema()); let (sender, receiver) = std::sync::mpsc::channel(); - let table = Table::new_with_handler("my_table", move |mut request| { - if request.url().path() == "/v1/table/my_table/insert/" { - assert_eq!(request.method(), "POST"); - assert!(request - .url() - .query_pairs() - .filter(|(k, _)| k == "mode") - .all(|(_, v)| v == "append")); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - ARROW_STREAM_CONTENT_TYPE - ); - let mut body_out = reqwest::Body::from(Vec::new()); - std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); - sender.send(body_out).unwrap(); - http::Response::builder() + let table = + Table::new_with_handler("my_table", move |mut request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() .status(200) - .body(response_body.clone()) - .unwrap() - } else { - panic!("Unexpected request path: {}", request.url().path()); - } - }); + .body(describe_body.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => { + assert_eq!(request.method(), "POST"); + assert!(request + .url() + .query_pairs() + .filter(|(k, _)| k == "mode") + .all(|(_, v)| v == "append")); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + let mut body_out = reqwest::Body::from(Vec::new()); + std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + http::Response::builder() + .status(200) + .body(response_body.clone()) + .unwrap() + } + path => panic!("Unexpected request path: {}", path), + }); let result = table.add(data.clone()).execute().await.unwrap(); // Check version matches expected value @@ -2157,39 +2104,50 @@ mod tests { ) .unwrap(); + let describe_body = describe_response(&data.schema()); let (sender, receiver) = std::sync::mpsc::channel(); - let table = Table::new_with_handler("my_table", move |mut request| { - assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/v1/table/my_table/insert/"); - assert_eq!( - request - .url() - .query_pairs() - .find(|(k, _)| k == "mode") - .map(|kv| kv.1) - .as_deref(), - Some("overwrite"), - "Expected mode=overwrite" - ); - - assert_eq!( - request.headers().get("Content-Type").unwrap(), - ARROW_STREAM_CONTENT_TYPE - ); - - let mut body_out = reqwest::Body::from(Vec::new()); - std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); - sender.send(body_out).unwrap(); - - if old_server { - http::Response::builder().status(200).body("").unwrap() - } else { - http::Response::builder() + let table = + Table::new_with_handler("my_table", move |mut request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() .status(200) - .body(r#"{"version": 43}"#) - .unwrap() - } - }); + .body(describe_body.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => { + assert_eq!(request.method(), "POST"); + assert_eq!( + request + .url() + .query_pairs() + .find(|(k, _)| k == "mode") + .map(|kv| kv.1) + .as_deref(), + Some("overwrite"), + "Expected mode=overwrite" + ); + + assert_eq!( + 
request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + + let mut body_out = reqwest::Body::from(Vec::new()); + std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + + if old_server { + http::Response::builder() + .status(200) + .body("".to_string()) + .unwrap() + } else { + http::Response::builder() + .status(200) + .body(r#"{"version": 43}"#.to_string()) + .unwrap() + } + } + path => panic!("Unexpected request path: {}", path), + }); let result = table .add(data.clone()) @@ -2206,6 +2164,131 @@ mod tests { assert_eq!(&body, &expected_body); } + #[tokio::test] + async fn test_add_preprocessing() { + use crate::table::NaNVectorBehavior; + use arrow_array::{FixedSizeListArray, Float32Array, Int64Array}; + + // The table schema: {id: Int64, vec: FixedSizeList[3]} + let table_schema = Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new( + "vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3), + false, + ), + ]); + let json_schema = JsonSchema::try_from(&table_schema).unwrap(); + let describe_body = serde_json::to_string(&json!({ + "version": 1, + "schema": json_schema, + })) + .unwrap(); + + // ---- Part 1: NaN vectors should be rejected by default ---- + let nan_data = RecordBatch::try_new( + Arc::new(table_schema.clone()), + vec![ + Arc::new(Int64Array::from(vec![1])), + Arc::new( + FixedSizeListArray::try_new( + Arc::new(Field::new("item", DataType::Float32, true)), + 3, + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0])), + None, + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + let describe_body_clone = describe_body.clone(); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() + .status(200) + .body(describe_body_clone.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(), + path => panic!("Unexpected path: {path}"), + }); + + let result = table.add(nan_data).execute().await; + assert!(result.is_err(), "NaN vectors should be rejected by default"); + assert!( + result.unwrap_err().to_string().contains("NaN"), + "error should mention NaN" + ); + + // ---- Part 2: With Keep, should handle casting and missing columns ---- + // Input: {id: Int32 (needs cast to Int64), vec: FixedSizeList[3] with NaN} + // Table expects Int64 for id; NaN should be kept. 
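+        // With NaNVectorBehavior::Keep, the reject_nan stage is skipped and
+        // only the cast stage should transform the batch.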
+ let input_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3), + false, + ), + ]); + let cast_data = RecordBatch::try_new( + Arc::new(input_schema), + vec![ + Arc::new(Int32Array::from(vec![42])), + Arc::new( + FixedSizeListArray::try_new( + Arc::new(Field::new("item", DataType::Float32, true)), + 3, + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0])), + None, + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + let (sender, receiver) = std::sync::mpsc::channel(); + let table = + Table::new_with_handler("my_table", move |mut request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => { + let mut body_out = reqwest::Body::from(Vec::new()); + std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap() + } + path => panic!("Unexpected path: {path}"), + }); + + table + .add(cast_data) + .on_nan_vectors(NaNVectorBehavior::Keep) + .execute() + .await + .unwrap(); + + // Verify the data sent to the server was cast to the table schema. + let body = receiver.recv().unwrap(); + let body = collect_body(body).await; + let cursor = std::io::Cursor::new(body); + let mut reader = arrow_ipc::reader::StreamReader::try_new(cursor, None).unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64); + let ids: &Int64Array = batch.column(0).as_any().downcast_ref().unwrap(); + assert_eq!(ids.value(0), 42); + } + #[rstest] #[case(true)] #[case(false)] @@ -3572,23 +3655,29 @@ mod tests { ) .unwrap(); + let describe_body = describe_response(&data.schema()); let (sender, receiver) = std::sync::mpsc::channel(); let table = Table::new_with_handler("prod$metrics", move |mut request| { - if request.url().path() == "/v1/table/prod$metrics/insert/" { - assert_eq!(request.method(), "POST"); - assert_eq!( - request.headers().get("Content-Type").unwrap(), - ARROW_STREAM_CONTENT_TYPE - ); - let mut body_out = reqwest::Body::from(Vec::new()); - std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); - sender.send(body_out).unwrap(); - http::Response::builder() + match request.url().path() { + "/v1/table/prod$metrics/describe/" => http::Response::builder() .status(200) - .body(r#"{"version": 2}"#) - .unwrap() - } else { - panic!("Unexpected request path: {}", request.url().path()); + .body(describe_body.clone()) + .unwrap(), + "/v1/table/prod$metrics/insert/" => { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + let mut body_out = reqwest::Body::from(Vec::new()); + std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out); + sender.send(body_out).unwrap(); + http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap() + } + path => panic!("Unexpected request path: {}", path), } }); @@ -4480,93 +4569,70 @@ mod tests { } #[tokio::test] - async fn test_add_retries_rescannable_data() { - let call_count = Arc::new(AtomicUsize::new(0)); - let call_count_clone = call_count.clone(); - - // Configure with retries enabled (default is 3) - let config = crate::remote::ClientConfig::default(); - - let table = 
Table::new_with_handler_and_config( - "my_table", - move |_request| { - let count = call_count_clone.fetch_add(1, Ordering::SeqCst); - if count < 2 { - // First two attempts fail with a retryable error (409) - http::Response::builder().status(409).body("").unwrap() - } else { - // Third attempt succeeds - http::Response::builder() - .status(200) - .body(r#"{"version": 1}"#) - .unwrap() - } - }, - config, - ); - - // RecordBatch is rescannable - should retry and succeed + async fn test_add_insert_fails() { + // Verify that an HTTP error from the insert endpoint is properly + // surfaced with the status code intact. Use 400 (non-retryable). let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); - let result = table.add(batch).execute().await; + let describe_body = describe_response(&batch.schema()); - assert!( - result.is_ok(), - "Expected success after retries: {:?}", - result - ); - assert_eq!( - call_count.load(Ordering::SeqCst), - 3, - "Expected 2 failed attempts + 1 success = 3 total" - ); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => http::Response::builder() + .status(400) + .body("bad request".to_string()) + .unwrap(), + path => panic!("Unexpected request path: {}", path), + }); + + let result = table.add(batch).execute().await; + let err = result.unwrap_err(); + match &err { + Error::Http { status_code, .. } => { + assert_eq!(*status_code, Some(reqwest::StatusCode::BAD_REQUEST)); + } + other => panic!("Expected Http error, got: {:?}", other), + } } #[tokio::test] - async fn test_add_no_retry_for_non_rescannable() { - let call_count = Arc::new(AtomicUsize::new(0)); - let call_count_clone = call_count.clone(); - - // Configure with retries enabled - let config = crate::remote::ClientConfig::default(); - - let table = Table::new_with_handler_and_config( - "my_table", - move |_request| { - call_count_clone.fetch_add(1, Ordering::SeqCst); - // Always fail with retryable error - http::Response::builder().status(409).body("").unwrap() - }, - config, - ); - - // RecordBatchReader is NOT rescannable - should NOT retry + async fn test_add_retries_on_retryable_status() { + // Verify that rescannable data retries on retryable status codes (e.g. 502) + // and eventually succeeds. let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); - let reader: Box = Box::new( - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ); + let describe_body = describe_response(&batch.schema()); - let result = table.add(reader).execute().await; + let attempt = Arc::new(AtomicUsize::new(0)); + let attempt_clone = attempt.clone(); - // Should fail because we can't retry non-rescannable sources - assert!(result.is_err()); - // Right now, we actually do retry, so we get 3 failures. In the future - // this will change and we need to update the test. - assert!( - matches!( - result.unwrap_err(), - Error::Retry { - request_failures: 3, - .. 
+ let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/describe/" => http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap(), + "/v1/table/my_table/insert/" => { + let n = attempt_clone.fetch_add(1, Ordering::SeqCst); + if n < 2 { + http::Response::builder() + .status(502) + .body("bad gateway".to_string()) + .unwrap() + } else { + http::Response::builder() + .status(200) + .body(r#"{"version": 3}"#.to_string()) + .unwrap() + } } - ), - "Expected RequestFailed with status 409" - ); - // TODO: After we implement proper non-rescannable handling, uncomment below - // (This is blocked on getting Python and Node to pass down re-scannable data.) - // assert_eq!( - // call_count.load(Ordering::SeqCst), - // 1, - // "Expected only one attempt for non-rescannable source" - // ); + path => panic!("Unexpected request path: {}", path), + }); + + let result = table.add(batch).execute().await.unwrap(); + assert_eq!(result.version, 3); + assert_eq!(attempt.load(Ordering::SeqCst), 3); } } diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index 04caaaa4a..da8cb9af0 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -8,7 +8,6 @@ use std::sync::{Arc, Mutex}; use arrow_array::{ArrayRef, RecordBatch, UInt64Array}; use arrow_ipc::CompressionType; -use arrow_schema::ArrowError; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::EquivalenceProperties; @@ -76,7 +75,15 @@ impl RemoteInsertExec { self.add_result.lock().unwrap().clone() } - fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult { + /// Stream the input into an HTTP body as an Arrow IPC stream, capturing any + /// stream errors into the provided channel. Errors from the input plan + /// (e.g. NaN rejection) would otherwise be swallowed inside the HTTP body + /// upload; by stashing them in the channel we can surface them with their + /// original message after the request completes. + fn stream_as_http_body( + data: SendableRecordBatchStream, + error_tx: tokio::sync::oneshot::Sender, + ) -> DataFusionResult { let options = arrow_ipc::writer::IpcWriteOptions::default() .try_with_compression(Some(CompressionType::LZ4_FRAME))?; let writer = arrow_ipc::writer::StreamWriter::try_new_with_options( @@ -85,26 +92,44 @@ impl RemoteInsertExec { options, )?; - let stream = futures::stream::try_unfold((data, writer), move |(mut data, mut writer)| { - async move { + let stream = futures::stream::try_unfold( + (data, writer, Some(error_tx), false), + move |(mut data, mut writer, error_tx, finished)| async move { + if finished { + return Ok(None); + } match data.next().await { Some(Ok(batch)) => { - writer.write(&batch)?; + writer + .write(&batch) + .map_err(|e| std::io::Error::other(e.to_string()))?; let buffer = std::mem::take(writer.get_mut()); - Ok(Some((buffer, (data, writer)))) + Ok(Some((buffer, (data, writer, error_tx, false)))) + } + Some(Err(e)) => { + // Send the original error through the channel before + // returning a generic error to reqwest. + if let Some(tx) = error_tx { + let _ = tx.send(e); + } + Err(std::io::Error::other( + "input stream error (see error channel)", + )) } - Some(Err(e)) => Err(e), None => { - if let Err(ArrowError::IpcError(_msg)) = writer.finish() { - // Will error if already closed. 
- return Ok(None); - }; + writer + .finish() + .map_err(|e| std::io::Error::other(e.to_string()))?; let buffer = std::mem::take(writer.get_mut()); - Ok(Some((buffer, (data, writer)))) + if buffer.is_empty() { + Ok(None) + } else { + Ok(Some((buffer, (data, writer, None, true)))) + } } } - } - }); + }, + ); Ok(reqwest::Body::wrap_stream(stream)) } @@ -202,24 +227,41 @@ impl ExecutionPlan for RemoteInsertExec { request = request.query(&[("mode", "overwrite")]); } - let body = Self::stream_as_body(input_stream)?; + let (error_tx, mut error_rx) = tokio::sync::oneshot::channel(); + let body = Self::stream_as_http_body(input_stream, error_tx)?; let request = request.body(body); - let (request_id, response) = client - .send(request) - .await - .map_err(|e| DataFusionError::External(Box::new(e)))?; - - let response = - RemoteTable::::handle_table_not_found(&table_name, response, &request_id) + let result: DataFusionResult<(String, _)> = async { + let (request_id, response) = client + .send(request) .await .map_err(|e| DataFusionError::External(Box::new(e)))?; - let response = client - .check_response(&request_id, response) + let response = RemoteTable::::handle_table_not_found( + &table_name, + response, + &request_id, + ) .await .map_err(|e| DataFusionError::External(Box::new(e)))?; + let response = client + .check_response(&request_id, response) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + Ok((request_id, response)) + } + .await; + + // If the request failed due to an input stream error, surface the + // original error (e.g. NaN rejection) instead of the HTTP error. + if let Ok(stream_err) = error_rx.try_recv() { + return Err(stream_err); + } + + let (request_id, response) = result?; + let body_text = response.text().await.map_err(|e| { DataFusionError::External(Box::new(Error::Http { source: Box::new(e), diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index f1449e10c..31bb8a1b3 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -10,15 +10,18 @@ use datafusion_expr::Expr; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_physical_plan::ExecutionPlan; use futures::StreamExt; +use futures::TryStreamExt; use lance::dataset::builder::DatasetBuilder; pub use lance::dataset::ColumnAlteration; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; pub use lance::dataset::Version; -use lance::dataset::{InsertBuilder, WriteMode, WriteParams}; +use lance::dataset::WriteMode; +use lance::dataset::{InsertBuilder, WriteParams}; use lance::index::vector::utils::infer_vector_dim; use lance::index::vector::VectorIndexParams; use lance::io::{ObjectStoreParams, WrappingObjectStore}; +use lance_datafusion::exec::execute_plan; use lance_datafusion::utils::StreamingWriteSource; use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams}; use lance_index::vector::bq::RQBuildParams; @@ -40,7 +43,7 @@ use std::format; use std::path::Path; use std::sync::Arc; -use crate::data::scannable::{scannable_with_embeddings, Scannable}; +use crate::data::scannable::Scannable; use crate::database::Database; use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry}; use crate::error::{Error, Result}; @@ -49,6 +52,7 @@ use crate::index::IndexStatistics; use crate::index::{vector::suggested_num_sub_vectors, Index, IndexBuilder}; use crate::index::{IndexConfig, IndexStatisticsImpl}; use crate::query::{IntoQueryVector, Query, QueryExecutionOptions, TakeQuery, VectorQuery}; +use 
 use crate::utils::{
 supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type,
 supported_label_list_data_type, supported_vector_data_type, PatchReadParam, PatchWriteParam,
@@ -67,7 +71,7 @@ pub mod query;
 pub mod schema_evolution;
 pub mod update;
 use crate::index::waiter::wait_for_index;
-pub use add_data::{AddDataBuilder, AddDataMode, AddResult};
+pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
 pub use chrono::Duration;
 pub use delete::DeleteResult;
 use futures::future::join_all;
@@ -2110,28 +2114,41 @@ impl BaseTable for NativeTable {
 }

 async fn add(&self, add: AddDataBuilder) -> Result<AddResult> {
- let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
- mode: match add.mode {
- AddDataMode::Append => WriteMode::Append,
- AddDataMode::Overwrite => WriteMode::Overwrite,
- },
- ..Default::default()
- });
-
- // Apply embeddings if configured
 let table_def = self.table_definition().await?;
- let data =
- scannable_with_embeddings(add.data, &table_def, add.embedding_registry.as_ref())?;

 self.dataset.ensure_mutable()?;
+ let ds_wrapper = self.dataset.clone();
 let ds = self.dataset.get().await?;
- let dataset = InsertBuilder::new(ds)
- .with_params(&lance_params)
- .execute_stream(data)
- .await?;

- let version = dataset.manifest().version;
- self.dataset.update(dataset);
+ let table_schema = Schema::from(&ds.schema().clone());
+
+ let output = add.into_plan(&table_schema, &table_def)?;
+
+ let lance_params = output
+ .write_options
+ .lance_write_params
+ .unwrap_or(WriteParams {
+ mode: match output.mode {
+ AddDataMode::Append => WriteMode::Append,
+ AddDataMode::Overwrite => WriteMode::Overwrite,
+ },
+ ..Default::default()
+ });
+
+ let plan = Arc::new(InsertExec::new(
+ ds_wrapper.clone(),
+ ds,
+ output.plan,
+ lance_params,
+ ));
+
+ let stream = execute_plan(plan, Default::default())?;
+ stream
+ .try_collect::<Vec<_>>()
+ .await
+ .map_err(crate::Error::from)?;
+
+ let version = ds_wrapper.get().await?.manifest().version;

 Ok(AddResult { version })
 }
diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs
index 740e53011..ad64e36d7 100644
--- a/rust/lancedb/src/table/add_data.rs
+++ b/rust/lancedb/src/table/add_data.rs
@@ -3,13 +3,19 @@

 use std::sync::Arc;

+use arrow_schema::{DataType, Fields, Schema};
+use lance::dataset::WriteMode;
 use serde::{Deserialize, Serialize};

+use crate::data::scannable::scannable_with_embeddings;
 use crate::data::scannable::Scannable;
 use crate::embeddings::EmbeddingRegistry;
-use crate::Result;
+use crate::table::datafusion::cast::cast_to_table_schema;
+use crate::table::datafusion::reject_nan::reject_nan_vectors;
+use crate::table::datafusion::scannable_exec::ScannableExec;
+use crate::{Error, Result};

-use super::{BaseTable, WriteOptions};
+use super::{BaseTable, TableDefinition, WriteOptions};

 #[derive(Debug, Clone, Default)]
 pub enum AddDataMode {
@@ -29,12 +35,22 @@ pub struct AddResult {
 pub version: u64,
 }

+#[derive(Debug, Default, Clone, Copy)]
+pub enum NaNVectorBehavior {
+ /// Reject any vectors containing NaN values (the default)
+ #[default]
+ Error,
+ /// Allow NaN values to be added, but they will not be indexed for search
+ Keep,
+}
+
 /// A builder for configuring a [`crate::table::Table::add`] operation
 pub struct AddDataBuilder {
 pub(crate) parent: Arc<dyn BaseTable>,
 pub(crate) data: Box<dyn Scannable>,
 pub(crate) mode: AddDataMode,
 pub(crate) write_options: WriteOptions,
+ pub(crate) on_nan_vectors: NaNVectorBehavior,
 pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
 }

@@ -59,6 +75,7 @@ impl AddDataBuilder {
 data,
 mode: AddDataMode::Append,
 write_options: WriteOptions::default(),
+ on_nan_vectors: NaNVectorBehavior::default(),
 embedding_registry,
 }
 }
@@ -73,16 +90,121 @@ impl AddDataBuilder {
 self
 }

+ /// Configure how to handle NaN values in vector columns.
+ ///
+ /// By default, any vectors containing NaN values will be rejected with an
+ /// error, since NaNs cannot be indexed for search. Setting this to `Keep`
+ /// will allow NaN values to be added to the table, but they will not be
+ /// indexed and will not be searchable.
+ pub fn on_nan_vectors(mut self, behavior: NaNVectorBehavior) -> Self {
+ self.on_nan_vectors = behavior;
+ self
+ }
+
 pub async fn execute(self) -> Result<AddResult> {
 self.parent.clone().add(self).await
 }
+
+ /// Build a DataFusion execution plan that applies embeddings, casts data to
+ /// the table schema, and optionally rejects NaN vectors.
+ ///
+ /// Returns the plan along with whether the input is rescannable (for retry
+ /// decisions) and whether this is an overwrite operation.
+ pub(crate) fn into_plan(
+ mut self,
+ table_schema: &Schema,
+ table_def: &TableDefinition,
+ ) -> Result<PreprocessingOutput> {
+ let overwrite = self
+ .write_options
+ .lance_write_params
+ .as_ref()
+ .is_some_and(|p| matches!(p.mode, WriteMode::Overwrite))
+ || matches!(self.mode, AddDataMode::Overwrite);
+
+ if !overwrite {
+ validate_schema(&self.data.schema(), table_schema)?;
+ }
+
+ self.data =
+ scannable_with_embeddings(self.data, table_def, self.embedding_registry.as_ref())?;
+
+ let rescannable = self.data.rescannable();
+ let plan: Arc<dyn datafusion_physical_plan::ExecutionPlan> =
+ Arc::new(ScannableExec::new(self.data));
+ // Skip casting when overwriting — the input schema replaces the table schema.
+ let plan = if overwrite {
+ plan
+ } else {
+ cast_to_table_schema(plan, table_schema)?
+ };
+ let plan = match self.on_nan_vectors {
+ NaNVectorBehavior::Error => reject_nan_vectors(plan)?,
+ NaNVectorBehavior::Keep => plan,
+ };
+
+ Ok(PreprocessingOutput {
+ plan,
+ overwrite,
+ rescannable,
+ write_options: self.write_options,
+ mode: self.mode,
+ })
+ }
+}
+
+pub struct PreprocessingOutput {
+ pub plan: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
+ pub overwrite: bool,
+ pub rescannable: bool,
+ pub write_options: WriteOptions,
+ pub mode: AddDataMode,
+}
+
+/// Check that the input schema is valid for insert.
+///
+/// Fields can be in different orders, so match by name.
+///
+/// If a column exists in input but not in table, error (no extra columns allowed).
+///
+/// If a column exists in table but not in input, that is okay - it may be filled with nulls.
+///
+/// If the types are not exactly the same, we will attempt to cast later - so that is also okay at this stage.
+///
+/// If the nullability is different, that is also okay - we can relax nullability when casting.
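+///
+/// For example (illustrative only): given a table schema of
+/// `{id: Int64, text: Utf8}`, an input of `{id: Int32}` is accepted (the type
+/// is cast and `text` is filled with nulls), while `{id: Int64, extra: Utf8}`
+/// is rejected because `extra` is not a column of the table.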
+fn validate_schema(input: &Schema, table: &Schema) -> Result<()> {
+ validate_fields(input.fields(), table.fields())
+}
+
+fn validate_fields(input: &Fields, table: &Fields) -> Result<()> {
+ for field in input {
+ match table.iter().find(|f| f.name() == field.name()) {
+ None => {
+ return Err(Error::InvalidInput {
+ message: format!("field '{}' does not exist in table schema", field.name()),
+ });
+ }
+ Some(table_field) => {
+ if let (DataType::Struct(in_children), DataType::Struct(tbl_children)) =
+ (field.data_type(), table_field.data_type())
+ {
+ validate_fields(in_children, tbl_children)?;
+ }
+ }
+ }
+ }
+ Ok(())
 }

 #[cfg(test)]
 mod tests {
 use std::sync::Arc;

- use arrow_array::{record_batch, RecordBatch, RecordBatchIterator};
+ use arrow::datatypes::Float64Type;
+ use arrow_array::{
+ record_batch, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, ListArray,
+ RecordBatch, RecordBatchIterator,
+ };
 use arrow_schema::{ArrowError, DataType, Field, Schema};
 use futures::TryStreamExt;
 use lance::dataset::{WriteMode, WriteParams};
@@ -94,6 +216,7 @@ mod tests {
 EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry,
 };
 use crate::query::{ExecutableQuery, QueryBase, Select};
+ use crate::table::add_data::NaNVectorBehavior;
 use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions};
 use crate::test_utils::embeddings::MockEmbed;
 use crate::Error;
@@ -340,4 +463,248 @@ mod tests {
 assert_eq!(embedding_col.null_count(), 0);
 }
 }
+
+ #[tokio::test]
+ async fn test_add_casts_to_table_schema() {
+ let table_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int64, false),
+ Field::new("text", DataType::Utf8, false),
+ Field::new(
+ "embedding",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+ false,
+ ),
+ ]));
+
+ let input_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false), // Upcast integer
+ Field::new("text", DataType::LargeUtf8, false), // Re-encode string
+ // Cast list of float64 to fixed-size list of float32
+ // (This will only work if the list size is correct. See next test.)
+ Field::new(
+ "embedding",
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ ),
+ ]));
+
+ let db = connect("memory://").execute().await.unwrap();
+ let table = db
+ .create_empty_table("cast_test", table_schema.clone())
+ .execute()
+ .await
+ .unwrap();
+
+ let batch = RecordBatch::try_new(
+ input_schema,
+ vec![
+ Arc::new(Int32Array::from(vec![1, 2])),
+ Arc::new(LargeStringArray::from(vec!["hello", "world"])),
+ Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
+ Some(vec![0.1, 0.2, 0.3, 0.4].into_iter().map(Some)),
+ Some(vec![0.5, 0.6, 0.7, 0.8].into_iter().map(Some)),
+ ])),
+ ],
+ )
+ .unwrap();
+ table.add(batch).execute().await.unwrap();
+
+ let row_count = table.count_rows(None).await.unwrap();
+ assert_eq!(row_count, 2);
+ }
+
+ #[tokio::test]
+ async fn test_add_rejects_bad_vector_dimensions() {
+ let table_schema = Arc::new(Schema::new(vec![Field::new(
+ "embedding",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+ false,
+ )]));
+
+ let input_schema = Arc::new(Schema::new(vec![Field::new(
+ "embedding",
+ DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
+ false,
+ )]));
+
+ let db = connect("memory://").execute().await.unwrap();
+ let table = db
+ .create_empty_table("cast_test", table_schema.clone())
+ .execute()
+ .await
+ .unwrap();
+
+ let batch = RecordBatch::try_new(
+ input_schema,
+ vec![Arc::new(
+ ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
+ Some(vec![0.1, 0.2, 0.3, 0.4].into_iter().map(Some)),
+ Some(vec![0.5, 0.6, 0.8].into_iter().map(Some)),
+ ]),
+ )],
+ )
+ .unwrap();
+ let res = table.add(batch).execute().await;
+
+ // TODO: to recover the error, we will need a fix upstream in Lance.
+ // assert!(
+ // matches!(res, Err(Error::Arrow { source: ArrowError::CastError(_) })),
+ // "Expected schema mismatch error due to wrong vector dimensions, but got: {res:?}"
+ // );
+ assert!(
+ res.is_err(),
+ "Expected error due to wrong vector dimensions, but got success"
+ );
+ }
+
+ #[tokio::test]
+ async fn test_add_rejects_nan_vectors() {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "embedding",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+ false,
+ )]));
+
+ let db = connect("memory://").execute().await.unwrap();
+ let table = db
+ .create_empty_table("nan_test", schema.clone())
+ .execute()
+ .await
+ .unwrap();
+
+ let batch = RecordBatch::try_new(
+ schema,
+ vec![Arc::new(
+ FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 4,
+ Arc::new(Float32Array::from(vec![0.1, 0.2, f32::NAN, 0.4])),
+ None,
+ )
+ .unwrap(),
+ )],
+ )
+ .unwrap();
+ let res = table.add(batch.clone()).execute().await;
+ let err = res.unwrap_err();
+ assert!(
+ err.to_string().contains("NaN"),
+ "Expected error mentioning NaN values, but got: {err:?}"
+ );
+
+ table
+ .add(batch)
+ .on_nan_vectors(NaNVectorBehavior::Keep)
+ .execute()
+ .await
+ .unwrap();
+
+ let row_count = table.count_rows(None).await.unwrap();
+ assert_eq!(row_count, 1);
+ }
+
+ #[tokio::test]
+ async fn test_add_subschema() {
+ let data = record_batch!(("id", Int64, [4, 5]), ("text", Utf8, ["foo", "bar"])).unwrap();
+ let db = connect("memory://").execute().await.unwrap();
+ let table = db
+ .create_table("test", data.clone())
+ .execute()
+ .await
+ .unwrap();
+
+ let new_data = record_batch!(("id", Int64, [6, 7])).unwrap();
+ table.add(new_data).execute().await.unwrap();
+
+ assert_eq!(table.count_rows(None).await.unwrap(), 4);
+ assert_eq!(
+ table
+ .count_rows(Some("id IS NOT NULL".to_string()))
+ .await
+ .unwrap(),
+ 4
+ );
+ assert_eq!(
+ table
+ .count_rows(Some("text IS NOT NULL".to_string()))
+ .await
+ .unwrap(),
+ 2
+ );
+
+ // We can still cast
+ let new_data = record_batch!(("text", LargeUtf8, ["baz", "qux"])).unwrap();
+ table.add(new_data).execute().await.unwrap();
+
+ assert_eq!(table.count_rows(None).await.unwrap(), 6);
+ assert_eq!(
+ table
+ .count_rows(Some("id IS NOT NULL".to_string()))
+ .await
+ .unwrap(),
+ 4
+ );
+ assert_eq!(
+ table
+ .count_rows(Some("text IS NOT NULL".to_string()))
+ .await
+ .unwrap(),
+ 4
+ );
+
+ // Extra columns mean an error
+ let new_data =
+ record_batch!(("id", Int64, [8, 9]), ("extra", Utf8, ["extra1", "extra2"])).unwrap();
+ let res = table.add(new_data).execute().await;
+ assert!(
+ res.is_err(),
+ "Expected error due to extra column, but got: {res:?}"
+ );
+
+ // Insert with a subset of struct sub-fields
+ let struct_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int64, false),
+ Field::new(
+ "metadata",
+ DataType::Struct(
+ vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Utf8, true),
+ ]
+ .into(),
+ ),
+ true,
+ ),
+ ]));
+ let db2 = connect("memory://").execute().await.unwrap();
+ let table2 = db2
+ .create_empty_table("struct_test", struct_schema)
+ .execute()
+ .await
+ .unwrap();
+
+ // Insert with only the "a" sub-field of the struct
+ let sub_struct_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int64, false),
+ Field::new(
+ "metadata",
+ DataType::Struct(vec![Field::new("a", DataType::Int64, true)].into()),
+ true,
+ ),
+ ]));
+ let struct_batch = RecordBatch::try_new(
+ sub_struct_schema,
+ vec![
+ Arc::new(arrow_array::Int64Array::from(vec![1, 2])),
+ Arc::new(arrow_array::StructArray::from(vec![(
+ Arc::new(Field::new("a", DataType::Int64, true)),
+ Arc::new(arrow_array::Int64Array::from(vec![10, 20]))
+ as Arc<dyn arrow_array::Array>,
+ )])),
+ ],
+ )
+ .unwrap();
+ table2.add(struct_batch).execute().await.unwrap();
+ assert_eq!(table2.count_rows(None).await.unwrap(), 2);
+ }
 }
diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs
index c7ced18ea..aaa5b8d7f 100644
--- a/rust/lancedb/src/table/datafusion.rs
+++ b/rust/lancedb/src/table/datafusion.rs
@@ -3,7 +3,10 @@

 //! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
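+//!
+//! For inserts, the submodules below compose into a single plan. A rough
+//! sketch of the shape (see `AddDataBuilder::into_plan` for the real wiring):
+//!
+//! ```ignore
+//! let plan: Arc<dyn ExecutionPlan> = Arc::new(ScannableExec::new(data));
+//! let plan = cast_to_table_schema(plan, &table_schema)?; // align names/types
+//! let plan = reject_nan_vectors(plan)?; // default NaN guard
+//! ```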
+pub mod cast;
 pub mod insert;
+pub mod reject_nan;
+pub mod scannable_exec;
 pub mod udtf;

 use std::{collections::HashMap, sync::Arc};
diff --git a/rust/lancedb/src/table/datafusion/cast.rs b/rust/lancedb/src/table/datafusion/cast.rs
new file mode 100644
index 000000000..76459ea87
--- /dev/null
+++ b/rust/lancedb/src/table/datafusion/cast.rs
@@ -0,0 +1,498 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::sync::Arc;
+
+use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
+use datafusion::functions::core::{get_field, named_struct};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::{cast, Literal};
+use datafusion_physical_expr::ScalarFunctionExpr;
+use datafusion_physical_plan::expressions::Column;
+use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
+
+use crate::{Error, Result};
+
+pub fn cast_to_table_schema(
+ input: Arc<dyn ExecutionPlan>,
+ table_schema: &Schema,
+) -> Result<Arc<dyn ExecutionPlan>> {
+ let input_schema = input.schema();
+
+ if input_schema.fields() == table_schema.fields() {
+ return Ok(input);
+ }
+
+ let exprs = build_field_exprs(
+ input_schema.fields(),
+ table_schema.fields(),
+ &|idx| Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>,
+ &input_schema,
+ )?;
+
+ let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = exprs
+ .into_iter()
+ .map(|(expr, field)| (expr, field.name().clone()))
+ .collect();
+
+ let projection = ProjectionExec::try_new(exprs, input).map_err(crate::Error::from)?;
+
+ Ok(Arc::new(projection))
+}
+
+/// Build expressions to project input fields to match the table schema.
+///
+/// For each table field that exists in the input, produce an expression that
+/// reads from the input and casts if needed. Fields in the table but not in the
+/// input are omitted (the storage layer handles missing columns).
+fn build_field_exprs(
+ input_fields: &Fields,
+ table_fields: &Fields,
+ get_input_expr: &dyn Fn(usize) -> Arc<dyn PhysicalExpr>,
+ input_schema: &Schema,
+) -> Result<Vec<(Arc<dyn PhysicalExpr>, FieldRef)>> {
+ let config = Arc::new(ConfigOptions::default());
+ let mut result = Vec::new();
+
+ for table_field in table_fields {
+ let Some(input_idx) = input_fields
+ .iter()
+ .position(|f| f.name() == table_field.name())
+ else {
+ continue;
+ };
+
+ let input_field = &input_fields[input_idx];
+ let input_expr = get_input_expr(input_idx);
+
+ let expr = match (input_field.data_type(), table_field.data_type()) {
+ // Both are structs: recurse into sub-fields to handle subschemas and casts.
+ (DataType::Struct(in_children), DataType::Struct(tbl_children))
+ if in_children != tbl_children =>
+ {
+ let sub_exprs = build_field_exprs(
+ in_children,
+ tbl_children,
+ &|child_idx| {
+ let child_name = in_children[child_idx].name();
+ Arc::new(ScalarFunctionExpr::new(
+ &format!("get_field({child_name})"),
+ get_field(),
+ vec![
+ input_expr.clone(),
+ Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
+ ],
+ Arc::new(in_children[child_idx].as_ref().clone()),
+ config.clone(),
+ )) as Arc<dyn PhysicalExpr>
+ },
+ input_schema,
+ )?;
+
+ let output_struct_fields: Fields = sub_exprs
+ .iter()
+ .map(|(_, f)| f.clone())
+ .collect::<Vec<_>>()
+ .into();
+ let output_field: FieldRef = Arc::new(Field::new(
+ table_field.name(),
+ DataType::Struct(output_struct_fields),
+ table_field.is_nullable(),
+ ));
+
+ // Build named_struct(lit("a"), expr_a, lit("b"), expr_b, ...)
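+ // named_struct takes alternating (name literal, value expression)
+ // arguments, so the rebuilt struct contains exactly the table's
+ // sub-fields, each already cast by the recursive call above.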
+ let mut ns_args: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
+ for (sub_expr, sub_field) in &sub_exprs {
+ ns_args.push(Arc::new(Literal::new(ScalarValue::from(
+ sub_field.name().as_str(),
+ ))));
+ ns_args.push(sub_expr.clone());
+ }
+
+ let ns_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+ &format!("named_struct({})", table_field.name()),
+ named_struct(),
+ ns_args,
+ output_field.clone(),
+ config.clone(),
+ ));
+
+ result.push((ns_expr, output_field));
+ continue;
+ }
+ // Types match: pass through.
+ (inp, tbl) if inp == tbl => input_expr,
+ // Types differ: cast.
+ _ => cast(input_expr, input_schema, table_field.data_type().clone()).map_err(|e| {
+ Error::InvalidInput {
+ message: format!(
+ "cannot cast field '{}' from {} to {}: {}",
+ table_field.name(),
+ input_field.data_type(),
+ table_field.data_type(),
+ e
+ ),
+ }
+ })?,
+ };
+
+ let output_field = Arc::new(Field::new(
+ table_field.name(),
+ table_field.data_type().clone(),
+ table_field.is_nullable(),
+ ));
+ result.push((expr, output_field));
+ }
+
+ Ok(result)
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use arrow_array::{
+ Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray,
+ };
+ use arrow_schema::{DataType, Field, Schema};
+ use datafusion::prelude::SessionContext;
+ use datafusion_catalog::MemTable;
+ use futures::TryStreamExt;
+
+ use super::cast_to_table_schema;
+
+ async fn plan_from_batch(
+ batch: RecordBatch,
+ ) -> Arc<dyn datafusion_physical_plan::ExecutionPlan> {
+ let schema = batch.schema();
+ let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
+ let ctx = SessionContext::new();
+ ctx.register_table("t", Arc::new(table)).unwrap();
+ let df = ctx.table("t").await.unwrap();
+ df.create_physical_plan().await.unwrap()
+ }
+
+ async fn collect(plan: Arc<dyn datafusion_physical_plan::ExecutionPlan>) -> RecordBatch {
+ let ctx = SessionContext::new();
+ let stream = plan.execute(0, ctx.task_ctx()).unwrap();
+ let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
+ arrow_select::concat::concat_batches(&plan.schema(), &batches).unwrap()
+ }
+
+ #[tokio::test]
+ async fn test_noop_when_schemas_match() {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Utf8, false),
+ ]));
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(vec![1, 2])),
+ Arc::new(StringArray::from(vec!["x", "y"])),
+ ],
+ )
+ .unwrap();
+
+ let input = plan_from_batch(batch).await;
+ let input_ptr = Arc::as_ptr(&input);
+ let result = cast_to_table_schema(input, &schema).unwrap();
+ assert_eq!(Arc::as_ptr(&result), input_ptr);
+ }
+
+ #[tokio::test]
+ async fn test_simple_type_cast() {
+ let input_batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("val", DataType::Float32, false),
+ ])),
+ vec![
+ Arc::new(Int32Array::from(vec![1, 2, 3])),
+ Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])),
+ ],
+ )
+ .unwrap();
+
+ let table_schema = Schema::new(vec![
+ Field::new("id", DataType::Int64, false),
+ Field::new("val", DataType::Float64, false),
+ ]);
+
+ let plan = plan_from_batch(input_batch).await;
+ let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+ let result = collect(casted).await;
+
+ assert_eq!(result.schema().field(0).data_type(), &DataType::Int64);
+ assert_eq!(result.schema().field(1).data_type(), &DataType::Float64);
+
+ let ids: &Int64Array = result.column(0).as_any().downcast_ref().unwrap();
+ assert_eq!(ids.values(), &[1, 2, 3]);
+
+ let vals: &Float64Array =
result.column(1).as_any().downcast_ref().unwrap(); + assert!((vals.value(0) - 1.5).abs() < 1e-6); + assert!((vals.value(1) - 2.5).abs() < 1e-6); + assert!((vals.value(2) - 3.5).abs() < 1e-6); + } + + #[tokio::test] + async fn test_missing_table_field_skipped() { + // Input has "a", table expects "a" and "b". "b" is omitted from the + // projection since the storage layer fills in missing columns. + let input_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![10, 20]))], + ) + .unwrap(); + + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let plan = plan_from_batch(input_batch).await; + let casted = cast_to_table_schema(plan, &table_schema).unwrap(); + let result = collect(casted).await; + + assert_eq!(result.num_columns(), 1); + assert_eq!(result.schema().field(0).name(), "a"); + } + + #[tokio::test] + async fn test_extra_input_fields_dropped() { + // Input has "a" and "extra"; table only expects "a". + let input_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("extra", DataType::Utf8, false), + ])), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + + let table_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + + let plan = plan_from_batch(input_batch).await; + let casted = cast_to_table_schema(plan, &table_schema).unwrap(); + let result = collect(casted).await; + + assert_eq!(result.num_columns(), 1); + assert_eq!(result.schema().field(0).name(), "a"); + assert_eq!(result.schema().field(0).data_type(), &DataType::Int64); + } + + #[tokio::test] + async fn test_reorders_to_table_schema() { + let input_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + ])), + vec![ + Arc::new(StringArray::from(vec!["x", "y"])), + Arc::new(Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + + let plan = plan_from_batch(input_batch).await; + let casted = cast_to_table_schema(plan, &table_schema).unwrap(); + let result = collect(casted).await; + + assert_eq!(result.schema().field(0).name(), "a"); + assert_eq!(result.schema().field(1).name(), "b"); + + let a: &Int32Array = result.column(0).as_any().downcast_ref().unwrap(); + assert_eq!(a.values(), &[1, 2]); + let b: &StringArray = result.column(1).as_any().downcast_ref().unwrap(); + assert_eq!(b.value(0), "x"); + } + + #[tokio::test] + async fn test_struct_subfield_cast() { + // Input struct has {x: Int32, y: Int32}, table expects {x: Int64, y: Int64}. 
+ let inner_fields = vec![
+ Field::new("x", DataType::Int32, false),
+ Field::new("y", DataType::Int32, false),
+ ];
+ let struct_array = StructArray::from(vec![
+ (
+ Arc::new(inner_fields[0].clone()),
+ Arc::new(Int32Array::from(vec![1, 2])) as _,
+ ),
+ (
+ Arc::new(inner_fields[1].clone()),
+ Arc::new(Int32Array::from(vec![3, 4])) as _,
+ ),
+ ]);
+ let input_batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![Field::new(
+ "s",
+ DataType::Struct(inner_fields.into()),
+ false,
+ )])),
+ vec![Arc::new(struct_array)],
+ )
+ .unwrap();
+
+ let table_inner = vec![
+ Field::new("x", DataType::Int64, false),
+ Field::new("y", DataType::Int64, false),
+ ];
+ let table_schema = Schema::new(vec![Field::new(
+ "s",
+ DataType::Struct(table_inner.into()),
+ false,
+ )]);
+
+ let plan = plan_from_batch(input_batch).await;
+ let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+ let result = collect(casted).await;
+
+ let struct_col = result
+ .column(0)
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .unwrap();
+ assert_eq!(struct_col.column(0).data_type(), &DataType::Int64);
+ assert_eq!(struct_col.column(1).data_type(), &DataType::Int64);
+
+ let x: &Int64Array = struct_col.column(0).as_any().downcast_ref().unwrap();
+ assert_eq!(x.values(), &[1, 2]);
+ let y: &Int64Array = struct_col.column(1).as_any().downcast_ref().unwrap();
+ assert_eq!(y.values(), &[3, 4]);
+ }
+
+ #[tokio::test]
+ async fn test_struct_subschema() {
+ // Input struct has {x, y, z}, table only expects {x, z}.
+ let inner_fields = vec![
+ Field::new("x", DataType::Int32, false),
+ Field::new("y", DataType::Int32, false),
+ Field::new("z", DataType::Int32, false),
+ ];
+ let struct_array = StructArray::from(vec![
+ (
+ Arc::new(inner_fields[0].clone()),
+ Arc::new(Int32Array::from(vec![1, 2])) as _,
+ ),
+ (
+ Arc::new(inner_fields[1].clone()),
+ Arc::new(Int32Array::from(vec![10, 20])) as _,
+ ),
+ (
+ Arc::new(inner_fields[2].clone()),
+ Arc::new(Int32Array::from(vec![100, 200])) as _,
+ ),
+ ]);
+ let input_batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![Field::new(
+ "s",
+ DataType::Struct(inner_fields.into()),
+ false,
+ )])),
+ vec![Arc::new(struct_array)],
+ )
+ .unwrap();
+
+ let table_inner = vec![
+ Field::new("x", DataType::Int32, false),
+ Field::new("z", DataType::Int32, false),
+ ];
+ let table_schema = Schema::new(vec![Field::new(
+ "s",
+ DataType::Struct(table_inner.into()),
+ false,
+ )]);
+
+ let plan = plan_from_batch(input_batch).await;
+ let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+ let result = collect(casted).await;
+
+ let struct_col = result
+ .column(0)
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .unwrap();
+ assert_eq!(struct_col.num_columns(), 2);
+
+ let x: &Int32Array = struct_col
+ .column_by_name("x")
+ .unwrap()
+ .as_any()
+ .downcast_ref()
+ .unwrap();
+ assert_eq!(x.values(), &[1, 2]);
+ let z: &Int32Array = struct_col
+ .column_by_name("z")
+ .unwrap()
+ .as_any()
+ .downcast_ref()
+ .unwrap();
+ assert_eq!(z.values(), &[100, 200]);
+ }
+
+ #[tokio::test]
+ async fn test_incompatible_cast_errors() {
+ let input_batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Binary, false)])),
+ vec![Arc::new(arrow_array::BinaryArray::from_vec(vec![b"hi"]))],
+ )
+ .unwrap();
+
+ let table_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+ let plan = plan_from_batch(input_batch).await;
+ let result = cast_to_table_schema(plan, &table_schema);
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("cannot cast field 'a'"),
+ "unexpected error: {err_msg}"
+ );
+ }
+
+ #[tokio::test]
+ async fn test_mixed_cast_and_passthrough() {
+ // "a" needs cast (Int32→Int64), "b" passes through unchanged.
+ let input_batch = RecordBatch::try_new(
+ Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Utf8, false),
+ ])),
+ vec![
+ Arc::new(Int32Array::from(vec![7, 8])),
+ Arc::new(StringArray::from(vec!["hello", "world"])),
+ ],
+ )
+ .unwrap();
+
+ let table_schema = Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Utf8, false),
+ ]);
+
+ let plan = plan_from_batch(input_batch).await;
+ let casted = cast_to_table_schema(plan, &table_schema).unwrap();
+ let result = collect(casted).await;
+
+ assert_eq!(result.schema().field(0).data_type(), &DataType::Int64);
+ assert_eq!(result.schema().field(1).data_type(), &DataType::Utf8);
+
+ let a: &Int64Array = result.column(0).as_any().downcast_ref().unwrap();
+ assert_eq!(a.values(), &[7, 8]);
+ let b: &StringArray = result.column(1).as_any().downcast_ref().unwrap();
+ assert_eq!(b.value(0), "hello");
+ assert_eq!(b.value(1), "world");
+ }
+}
diff --git a/rust/lancedb/src/table/datafusion/reject_nan.rs b/rust/lancedb/src/table/datafusion/reject_nan.rs
new file mode 100644
index 000000000..ca0ba69db
--- /dev/null
+++ b/rust/lancedb/src/table/datafusion/reject_nan.rs
@@ -0,0 +1,269 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! A DataFusion projection that rejects vectors containing NaN values.
+
+use std::any::Any;
+use std::sync::{Arc, LazyLock};
+
+use arrow_array::{Array, FixedSizeListArray};
+use arrow_schema::{DataType, Field, FieldRef};
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use datafusion_physical_expr::ScalarFunctionExpr;
+use datafusion_physical_plan::expressions::Column;
+use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
+
+use crate::{Error, Result};
+
+static REJECT_NAN_UDF: LazyLock<Arc<datafusion_expr::ScalarUDF>> =
+ LazyLock::new(|| Arc::new(datafusion_expr::ScalarUDF::from(RejectNanUdf::new())));
+
+/// Returns true if the field is a vector column: a FixedSizeList of floats.
+fn is_vector_field(field: &Field) -> bool {
+ if let DataType::FixedSizeList(child, _) = field.data_type() {
+ matches!(
+ child.data_type(),
+ DataType::Float16 | DataType::Float32 | DataType::Float64
+ )
+ } else {
+ false
+ }
+}
+
+/// Wraps the input plan with a projection that checks vector columns for NaN values.
+///
+/// Non-vector columns pass through unchanged. Vector columns are wrapped with a
+/// UDF that returns the column as-is if no NaNs are present, or errors otherwise.
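+///
+/// A minimal usage sketch (assuming `input` is any `ExecutionPlan` whose
+/// schema includes a FixedSizeList-of-float column):
+///
+/// ```ignore
+/// let guarded = reject_nan_vectors(input)?;
+/// // Executing `guarded` yields the same batches, or fails with
+/// // "Vector column contains NaN values" if a valid row holds a NaN.
+/// ```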
+pub fn reject_nan_vectors(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
+ let schema = input.schema();
+ let config = Arc::new(ConfigOptions::default());
+ let udf = REJECT_NAN_UDF.clone();
+
+ let mut has_vector_cols = false;
+ let mut exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
+
+ for (idx, field) in schema.fields().iter().enumerate() {
+ let col_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new(field.name(), idx));
+
+ if is_vector_field(field) {
+ has_vector_cols = true;
+ let wrapped: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+ &format!("reject_nan({})", field.name()),
+ udf.clone(),
+ vec![col_expr],
+ Arc::clone(field) as FieldRef,
+ config.clone(),
+ ));
+ exprs.push((wrapped, field.name().clone()));
+ } else {
+ exprs.push((col_expr, field.name().clone()));
+ }
+ }
+
+ if !has_vector_cols {
+ return Ok(input);
+ }
+
+ let projection = ProjectionExec::try_new(exprs, input).map_err(Error::from)?;
+ Ok(Arc::new(projection))
+}
+
+/// A scalar UDF that passes through FixedSizeList arrays unchanged, but errors
+/// if any float values in the list are NaN.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct RejectNanUdf {
+ signature: Signature,
+}
+
+impl RejectNanUdf {
+ fn new() -> Self {
+ Self {
+ signature: Signature::any(1, Volatility::Immutable),
+ }
+ }
+}
+
+impl ScalarUDFImpl for RejectNanUdf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "reject_nan"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn invoke_with_args(
+ &self,
+ args: ScalarFunctionArgs,
+ ) -> datafusion_common::Result<ColumnarValue> {
+ let arg = &args.args[0];
+ match arg {
+ ColumnarValue::Array(array) => {
+ check_no_nans(array)?;
+ Ok(ColumnarValue::Array(array.clone()))
+ }
+ ColumnarValue::Scalar(_) => Ok(arg.clone()),
+ }
+ }
+}
+
+fn check_no_nans(array: &dyn Array) -> datafusion_common::Result<()> {
+ let fsl = array
+ .as_any()
+ .downcast_ref::<FixedSizeListArray>()
+ .ok_or_else(|| {
+ datafusion_common::DataFusionError::Internal(
+ "reject_nan expected FixedSizeList".to_string(),
+ )
+ })?;
+
+ // Only inspect elements that are both in a valid parent row and non-null
+ // themselves. Values backing null parent rows or null child elements may
+ // contain garbage (including NaN) per the Arrow spec.
+ let has_nan = (0..fsl.len()).filter(|i| fsl.is_valid(*i)).any(|i| {
+ let row = fsl.value(i);
+ match row.data_type() {
+ DataType::Float16 => row
+ .as_any()
+ .downcast_ref::<arrow_array::Float16Array>()
+ .unwrap()
+ .iter()
+ .any(|v| v.is_some_and(|v| v.is_nan())),
+ DataType::Float32 => row
+ .as_any()
+ .downcast_ref::<arrow_array::Float32Array>()
+ .unwrap()
+ .iter()
+ .any(|v| v.is_some_and(|v| v.is_nan())),
+ DataType::Float64 => row
+ .as_any()
+ .downcast_ref::<arrow_array::Float64Array>()
+ .unwrap()
+ .iter()
+ .any(|v| v.is_some_and(|v| v.is_nan())),
+ _ => false,
+ }
+ });
+
+ if has_nan {
+ return Err(datafusion_common::DataFusionError::ArrowError(
+ Box::new(arrow_schema::ArrowError::ComputeError(
+ "Vector column contains NaN values".to_string(),
+ )),
+ None,
+ ));
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_array::Float32Array;
+
+ #[test]
+ fn test_passes_clean_vectors() {
+ let fsl = FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 2,
+ Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+ None,
+ )
+ .unwrap();
+ assert!(check_no_nans(&fsl).is_ok());
+ }
+
+ #[test]
+ fn test_rejects_nan_vectors() {
+ let fsl = FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 2,
+ Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, 4.0])),
+ None,
+ )
+ .unwrap();
+ assert!(check_no_nans(&fsl).is_err());
+ }
+
+ #[test]
+ fn test_skips_null_rows() {
+ // Values backing null rows may contain NaN per the Arrow spec.
+ // We should not reject a batch just because of garbage in null slots.
+ let values = Float32Array::from(vec![1.0, 2.0, f32::NAN, f32::NAN]);
+ let fsl = FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 2,
+ Arc::new(values),
+ // Row 0 is valid [1.0, 2.0], row 1 is null [NAN, NAN]
+ Some(vec![true, false].into()),
+ )
+ .unwrap();
+ assert!(fsl.is_null(1));
+ assert!(check_no_nans(&fsl).is_ok());
+ }
+
+ #[test]
+ fn test_skips_null_elements_within_valid_row() {
+ // A valid row with null child elements: the underlying buffer may hold
+ // NaN but the null bitmap says they're absent — should not reject.
+ let values = Float32Array::from(vec![
+ Some(1.0),
+ None, // null element — buffer may contain NaN
+ Some(3.0),
+ None, // null element
+ ]);
+ let fsl = FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 2,
+ Arc::new(values),
+ None, // both rows are valid
+ )
+ .unwrap();
+ assert!(check_no_nans(&fsl).is_ok());
+ }
+
+ #[test]
+ fn test_rejects_nan_in_valid_row_with_nulls_present() {
+ // Row 0 is null, row 1 is valid but contains NaN — should reject.
+ let values = Float32Array::from(vec![0.0, 0.0, 1.0, f32::NAN]);
+ let fsl = FixedSizeListArray::try_new(
+ Arc::new(Field::new("item", DataType::Float32, true)),
+ 2,
+ Arc::new(values),
+ Some(vec![false, true].into()),
+ )
+ .unwrap();
+ assert!(check_no_nans(&fsl).is_err());
+ }
+
+ #[test]
+ fn test_is_vector_field() {
+ assert!(is_vector_field(&Field::new(
+ "v",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+ false,
+ )));
+ assert!(is_vector_field(&Field::new(
+ "v",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 4),
+ false,
+ )));
+ assert!(!is_vector_field(&Field::new("id", DataType::Int32, false)));
+ assert!(!is_vector_field(&Field::new(
+ "v",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4),
+ false,
+ )));
+ }
+}
diff --git a/rust/lancedb/src/table/datafusion/scannable_exec.rs b/rust/lancedb/src/table/datafusion/scannable_exec.rs
new file mode 100644
index 000000000..1dafa55de
--- /dev/null
+++ b/rust/lancedb/src/table/datafusion/scannable_exec.rs
@@ -0,0 +1,118 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use core::fmt;
+use std::sync::{Arc, Mutex};
+
+use datafusion_common::{stats::Precision, DataFusionError, Result as DFResult, Statistics};
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_plan::{
+ execution_plan::EmissionType, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+};
+
+use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};
+
+pub struct ScannableExec {
+ // We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
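+ // The lock is only taken in `execute()`, which also recovers a poisoned
+ // lock via `into_inner`, so a panicked earlier scan cannot wedge the plan.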
+ source: Mutex<Box<dyn Scannable>>,
+ num_rows: Option<usize>,
+ properties: PlanProperties,
+}
+
+impl std::fmt::Debug for ScannableExec {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("ScannableExec")
+ .field("schema", &self.schema())
+ .field("num_rows", &self.num_rows)
+ .finish()
+ }
+}
+
+impl ScannableExec {
+ pub fn new(source: Box<dyn Scannable>) -> Self {
+ let schema = source.schema();
+ let eq_properties = EquivalenceProperties::new(schema);
+ let properties = PlanProperties::new(
+ eq_properties,
+ Partitioning::UnknownPartitioning(1),
+ EmissionType::Incremental,
+ datafusion_physical_plan::execution_plan::Boundedness::Bounded,
+ );
+
+ let num_rows = source.num_rows();
+ let source = Mutex::new(source);
+ Self {
+ source,
+ num_rows,
+ properties,
+ }
+ }
+}
+
+impl DisplayAs for ScannableExec {
+ fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "ScannableExec: num_rows={:?}", self.num_rows)
+ }
+}
+
+impl ExecutionPlan for ScannableExec {
+ fn name(&self) -> &str {
+ "ScannableExec"
+ }
+
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn properties(&self) -> &PlanProperties {
+ &self.properties
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+ vec![]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> DFResult<Arc<dyn ExecutionPlan>> {
+ if !children.is_empty() {
+ return Err(DataFusionError::Internal(
+ "ScannableExec does not have children".to_string(),
+ ));
+ }
+ Ok(self)
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ _context: Arc<TaskContext>,
+ ) -> DFResult<SendableRecordBatchStream> {
+ if partition != 0 {
+ return Err(DataFusionError::Internal(format!(
+ "ScannableExec only supports partition 0, got {}",
+ partition
+ )));
+ }
+
+ let stream = match self.source.lock() {
+ Ok(mut guard) => guard.scan_as_stream(),
+ Err(poison) => poison.into_inner().scan_as_stream(),
+ };
+
+ Ok(stream.into_df_stream())
+ }
+
+ fn partition_statistics(&self, _partition: Option<usize>) -> DFResult<Statistics> {
+ Ok(Statistics {
+ num_rows: self
+ .num_rows
+ .map(Precision::Exact)
+ .unwrap_or(Precision::Absent),
+ total_byte_size: Precision::Absent,
+ column_statistics: vec![],
+ })
+ }
+}