From c0230f91d2518cea3a866dc609aac5051551b5e1 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 13 Feb 2026 14:18:36 -0800 Subject: [PATCH] feat(rust)!: accept `RecordBatch`, `Vec<RecordBatch>` in `create_table()` and `Table.add()` (#2948) BREAKING CHANGE: Arbitrary `impl RecordBatchReader` is no longer accepted, it must be made into `Box<dyn RecordBatchReader + Send>`. This PR replaces `IntoArrow` with a new trait `Scannable` to define input row data. This provides the following advantages: 1. **We can implement `Scannable` for more types than `IntoArrow`, such as `RecordBatch` and `Vec<RecordBatch>`.** The `IntoArrow` trait was implemented for arbitrary `T: RecordBatchReader`, and the Rust compiler would prevent us from implementing it for foreign types like `RecordBatch` because (theoretically) those types might implement `RecordBatchReader` in the future. That's why we implement `Scannable` for `Box<dyn RecordBatchReader + Send>` instead; since it's a concrete type it doesn't block implementing for other foreign types. 2. **We can potentially replay `Scannable` values**. Previously, we had to choose between buffering all data in memory and supporting retries of writes. But because `Scannable` things can optionally support re-scanning, we now have a way of supporting retries while also streaming. 3. **`Scannable` can provide hints like `num_rows`, which can be used to schedule parallel writers.** Without knowing the total number of rows, it's difficult to know whether it's worth writing multiple files in parallel. We don't yet fully take advantage of (2) and (3), but will in future PRs. For (2), in order to be ready to leverage this, we need to hook the `Scannable` implementation up to Python and NodeJS bindings. Right now they always pass down a stream, but we want to make sure they support retries when possible. And for (3), this will need to be hooked up to #2939 and to a pipeline for running pre-processing steps (like embedding generation). ## Other changes * Moved `create_table` and `add_data` into their own modules. 
I've created a follow up issue to split up `table.rs` further, as it's by far the largest file: https://github.com/lancedb/lancedb/issues/2949 * Eliminated the `HAS_DATA` generic for `CreateTableBuilder`. I didn't see any public-facing places where we differentiated methods, which is why I felt this simplification was okay. * Added an `Error::External` variant and integrated some conversions to allow certain errors to pass through transparently. This will fully work once we upgrade Lance and get to take advantage of changes in https://github.com/lance-format/lance/pull/5606 * Added LZ4 compression support for write requests to remote endpoints. I checked and this has been supported on the server for > 1 year. --------- Co-authored-by: Claude Opus 4.6 --- nodejs/src/connection.rs | 1 + python/src/connection.rs | 3 +- python/src/table.rs | 3 +- rust/lancedb/examples/bedrock.rs | 10 +- rust/lancedb/examples/full_text_search.rs | 7 +- rust/lancedb/examples/hybrid_search.rs | 8 +- rust/lancedb/examples/ivf_pq.rs | 11 +- rust/lancedb/examples/openai.rs | 40 +- .../lancedb/examples/sentence_transformers.rs | 8 +- rust/lancedb/examples/simple.rs | 35 +- rust/lancedb/src/connection.rs | 697 +----------------- rust/lancedb/src/connection/create_table.rs | 612 +++++++++++++++ rust/lancedb/src/data.rs | 1 + rust/lancedb/src/data/scannable.rs | 580 +++++++++++++++ rust/lancedb/src/database.rs | 52 +- rust/lancedb/src/database/listing.rs | 127 +--- rust/lancedb/src/database/namespace.rs | 21 +- .../src/dataloader/permutation/builder.rs | 12 +- rust/lancedb/src/embeddings.rs | 188 +++-- rust/lancedb/src/error.rs | 24 +- rust/lancedb/src/io/object_store.rs | 3 +- rust/lancedb/src/ipc.rs | 4 +- rust/lancedb/src/lib.rs | 61 +- rust/lancedb/src/query.rs | 35 +- rust/lancedb/src/remote/client.rs | 12 +- rust/lancedb/src/remote/db.rs | 58 +- rust/lancedb/src/remote/table.rs | 279 ++++++- rust/lancedb/src/remote/util.rs | 51 +- rust/lancedb/src/table.rs | 431 +++-------- 
rust/lancedb/src/table/add_data.rs | 343 +++++++++ rust/lancedb/src/table/datafusion.rs | 59 +- rust/lancedb/src/table/datafusion/insert.rs | 54 +- rust/lancedb/src/table/datafusion/udtf/fts.rs | 166 +---- rust/lancedb/src/table/delete.rs | 19 +- rust/lancedb/src/table/optimize.rs | 84 +-- rust/lancedb/src/table/schema_evolution.rs | 86 +-- rust/lancedb/src/table/update.rs | 104 ++- rust/lancedb/src/test_utils.rs | 1 + rust/lancedb/src/test_utils/datagen.rs | 10 +- rust/lancedb/src/test_utils/embeddings.rs | 59 ++ rust/lancedb/tests/embedding_registry_test.rs | 3 +- rust/lancedb/tests/object_store_test.rs | 10 +- 42 files changed, 2466 insertions(+), 1906 deletions(-) create mode 100644 rust/lancedb/src/connection/create_table.rs create mode 100644 rust/lancedb/src/data/scannable.rs create mode 100644 rust/lancedb/src/table/add_data.rs create mode 100644 rust/lancedb/src/test_utils/embeddings.rs diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index a117ab41d..824ae7562 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -13,6 +13,7 @@ use crate::header::JsHeaderProvider; use crate::table::Table; use crate::ConnectionOptions; use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection}; + use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema}; #[napi] diff --git a/python/src/connection.rs b/python/src/connection.rs index 2bd333ea0..7df89b101 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -121,7 +121,8 @@ impl Connection { let mode = Self::parse_create_mode_str(mode)?; - let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; + let batches: Box = + Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?); let mut builder = inner.create_table(name, batches).mode(mode); diff --git a/python/src/table.rs b/python/src/table.rs index 41e4df35f..353b22ff0 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -296,7 +296,8 @@ impl Table { data: Bound<'_, PyAny>, mode: 
String, ) -> PyResult> { - let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; + let batches: Box = + Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?); let mut op = self_.inner_ref()?.add(batches); if mode == "append" { op = op.mode(AddDataMode::Append); diff --git a/rust/lancedb/examples/bedrock.rs b/rust/lancedb/examples/bedrock.rs index 5cc7e0cbe..365453b60 100644 --- a/rust/lancedb/examples/bedrock.rs +++ b/rust/lancedb/examples/bedrock.rs @@ -3,13 +3,12 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{Float64Array, Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use aws_config::Region; use aws_sdk_bedrockruntime::Client; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, query::{ExecutableQuery, QueryBase}, @@ -67,7 +66,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![ Field::new("id", DataType::Int32, true), Field::new("text", DataType::Utf8, false), @@ -83,10 +82,9 @@ fn make_data() -> impl IntoArrow { ]); let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); let schema = Arc::new(schema); - let rb = RecordBatch::try_new( + RecordBatch::try_new( schema.clone(), vec![Arc::new(id), Arc::new(text), Arc::new(price)], ) - .unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + .unwrap() } diff --git a/rust/lancedb/examples/full_text_search.rs b/rust/lancedb/examples/full_text_search.rs index 264c41487..15d7d7d88 100644 --- a/rust/lancedb/examples/full_text_search.rs +++ b/rust/lancedb/examples/full_text_search.rs @@ -3,12 +3,13 @@ use std::sync::Arc; -use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray}; +use arrow_array::{Int32Array, 
RecordBatch, RecordBatchIterator, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lance_index::scalar::FullTextSearchQuery; use lancedb::connection::Connection; + use lancedb::index::scalar::FtsIndexBuilder; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -29,7 +30,7 @@ async fn main() -> Result<()> { Ok(()) } -fn create_some_records() -> Result> { +fn create_some_records() -> Result> { const TOTAL: usize = 1000; let schema = Arc::new(Schema::new(vec![ @@ -66,7 +67,7 @@ fn create_some_records() -> Result> { } async fn create_table(db: &Connection) -> Result { - let initial_data: Box = create_some_records()?; + let initial_data = create_some_records()?; let tbl = db.create_table("my_table", initial_data).execute().await?; Ok(tbl) } diff --git a/rust/lancedb/examples/hybrid_search.rs b/rust/lancedb/examples/hybrid_search.rs index 8a8dcda51..dc1a80a39 100644 --- a/rust/lancedb/examples/hybrid_search.rs +++ b/rust/lancedb/examples/hybrid_search.rs @@ -1,14 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use arrow_array::{RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lance_index::scalar::FullTextSearchQuery; use lancedb::index::scalar::FtsIndexBuilder; use lancedb::index::Index; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{ sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition, @@ -70,7 +69,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]); let facts = StringArray::from_iter_values(vec![ @@ -101,8 +100,7 @@ fn make_data() -> impl IntoArrow { "The first chatbot was ELIZA, created in the 1960s.", ]); let schema = Arc::new(schema); - let rb = 
RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap() } async fn create_index(table: &Table) -> Result<()> { diff --git a/rust/lancedb/examples/ivf_pq.rs b/rust/lancedb/examples/ivf_pq.rs index 9ce561c13..d171d5a45 100644 --- a/rust/lancedb/examples/ivf_pq.rs +++ b/rust/lancedb/examples/ivf_pq.rs @@ -8,13 +8,12 @@ use std::sync::Arc; use arrow_array::types::Float32Type; -use arrow_array::{ - FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, -}; +use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lancedb::connection::Connection; + use lancedb::index::vector::IvfPqIndexBuilder; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -34,7 +33,7 @@ async fn main() -> Result<()> { Ok(()) } -fn create_some_records() -> Result> { +fn create_some_records() -> Result> { const TOTAL: usize = 1000; const DIM: usize = 128; @@ -73,9 +72,9 @@ fn create_some_records() -> Result> { } async fn create_table(db: &Connection) -> Result
{ - let initial_data: Box = create_some_records()?; + let initial_data = create_some_records()?; let tbl = db - .create_table("my_table", Box::new(initial_data)) + .create_table("my_table", initial_data) .execute() .await .unwrap(); diff --git a/rust/lancedb/examples/openai.rs b/rust/lancedb/examples/openai.rs index 73954f34a..2194e519a 100644 --- a/rust/lancedb/examples/openai.rs +++ b/rust/lancedb/examples/openai.rs @@ -5,11 +5,9 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_array::{RecordBatch, StringArray}; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, query::{ExecutableQuery, QueryBase}, @@ -64,26 +62,20 @@ async fn main() -> Result<()> { } // --8<-- [end:openai_embeddings] -fn make_data() -> impl IntoArrow { - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("text", DataType::Utf8, false), - Field::new("price", DataType::Float64, false), - ]); - - let id = Int32Array::from(vec![1, 2, 3, 4]); - let text = StringArray::from_iter_values(vec![ - "Black T-Shirt", - "Leather Jacket", - "Winter Parka", - "Hooded Sweatshirt", - ]); - let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); - let schema = Arc::new(schema); - let rb = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(id), Arc::new(text), Arc::new(price)], +fn make_data() -> RecordBatch { + arrow_array::record_batch!( + ("id", Int32, [1, 2, 3, 4]), + ( + "text", + Utf8, + [ + "Black T-Shirt", + "Leather Jacket", + "Winter Parka", + "Hooded Sweatshirt" + ] + ), + ("price", Float64, [10.0, 50.0, 100.0, 30.0]) ) - .unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + .unwrap() } diff --git a/rust/lancedb/examples/sentence_transformers.rs b/rust/lancedb/examples/sentence_transformers.rs index 
9250430d4..2fa7ed30b 100644 --- a/rust/lancedb/examples/sentence_transformers.rs +++ b/rust/lancedb/examples/sentence_transformers.rs @@ -3,11 +3,10 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{ sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition, @@ -59,7 +58,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]); let facts = StringArray::from_iter_values(vec![ @@ -90,6 +89,5 @@ fn make_data() -> impl IntoArrow { "The first chatbot was ELIZA, created in the 1960s.", ]); let schema = Arc::new(schema); - let rb = RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap() } diff --git a/rust/lancedb/examples/simple.rs b/rust/lancedb/examples/simple.rs index 57b7be70b..157083d0a 100644 --- a/rust/lancedb/examples/simple.rs +++ b/rust/lancedb/examples/simple.rs @@ -8,11 +8,9 @@ use std::sync::Arc; use arrow_array::types::Float32Type; -use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator}; +use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; - -use lancedb::arrow::IntoArrow; use lancedb::connection::Connection; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -59,7 +57,7 @@ async fn open_with_existing_tbl() -> Result<()> { Ok(()) } -fn create_some_records() -> Result { +fn create_some_records() -> Result { const TOTAL: usize = 1000; const DIM: usize = 128; @@ -76,25 +74,18 @@ fn create_some_records() -> 
Result { ])); // Create a RecordBatch stream. - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), - Arc::new( - FixedSizeListArray::from_iter_primitive::( - (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])), - DIM as i32, - ), - ), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + Ok(RecordBatch::try_new( schema.clone(), - ); - Ok(Box::new(batches)) + vec![ + Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), + Arc::new( + FixedSizeListArray::from_iter_primitive::( + (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])), + DIM as i32, + ), + ), + ], + )?) } async fn create_table(db: &Connection) -> Result { diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index f58aba4ea..1e22e7c8f 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -6,8 +6,8 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::RecordBatchReader; -use arrow_schema::{Field, SchemaRef}; +use arrow_array::RecordBatch; +use arrow_schema::SchemaRef; use lance::dataset::ReadParams; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, @@ -17,24 +17,20 @@ use lance_namespace::models::{ #[cfg(feature = "aws")] use object_store::aws::AwsCredential; -use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream}; -use crate::database::listing::{ - ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS, -}; +use crate::connection::create_table::CreateTableBuilder; +use crate::data::scannable::Scannable; +use crate::database::listing::ListingDatabase; use crate::database::{ - CloneTableRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database, - DatabaseOptions, OpenTableRequest, ReadConsistency, TableNamesRequest, -}; -use crate::embeddings::{ - EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, 
WithEmbeddings, + CloneTableRequest, Database, DatabaseOptions, OpenTableRequest, ReadConsistency, + TableNamesRequest, }; +use crate::embeddings::{EmbeddingRegistry, MemoryRegistry}; use crate::error::{Error, Result}; #[cfg(feature = "remote")] use crate::remote::{ client::ClientConfig, db::{OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION}, }; -use crate::table::{TableDefinition, WriteOptions}; use crate::Table; use lance::io::ObjectStoreParams; pub use lance_encoding::version::LanceFileVersion; @@ -42,6 +38,8 @@ pub use lance_encoding::version::LanceFileVersion; use lance_io::object_store::StorageOptions; use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider}; +mod create_table; + fn merge_storage_options( store_params: &mut ObjectStoreParams, pairs: impl IntoIterator, @@ -116,337 +114,6 @@ impl TableNamesBuilder { } } -pub struct NoData {} - -impl IntoArrow for NoData { - fn into_arrow(self) -> Result> { - unreachable!("NoData should never be converted to Arrow") - } -} - -// Stores the value given from the initial CreateTableBuilder::new call -// and defers errors until `execute` is called -enum CreateTableBuilderInitialData { - None, - Iterator(Result>), - Stream(Result), -} - -/// A builder for configuring a [`Connection::create_table`] operation -pub struct CreateTableBuilder { - parent: Arc, - embeddings: Vec<(EmbeddingDefinition, Arc)>, - embedding_registry: Arc, - request: CreateTableRequest, - // This is a bit clumsy but we defer errors until `execute` is called - // to maintain backwards compatibility - data: CreateTableBuilderInitialData, -} - -// Builder methods that only apply when we have initial data -impl CreateTableBuilder { - fn new( - parent: Arc, - name: String, - data: T, - embedding_registry: Arc, - ) -> Self { - let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::::default())); - Self { - parent, - request: CreateTableRequest::new( - name, - 
CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)), - ), - embeddings: Vec::new(), - embedding_registry, - data: CreateTableBuilderInitialData::Iterator(data.into_arrow()), - } - } - - fn new_streaming( - parent: Arc, - name: String, - data: T, - embedding_registry: Arc, - ) -> Self { - let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::::default())); - Self { - parent, - request: CreateTableRequest::new( - name, - CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)), - ), - embeddings: Vec::new(), - embedding_registry, - data: CreateTableBuilderInitialData::Stream(data.into_arrow()), - } - } - - /// Execute the create table operation - pub async fn execute(self) -> Result
{ - let embedding_registry = self.embedding_registry.clone(); - let parent = self.parent.clone(); - let request = self.into_request()?; - Ok(Table::new_with_embedding_registry( - parent.create_table(request).await?, - parent, - embedding_registry, - )) - } - - fn into_request(self) -> Result { - if self.embeddings.is_empty() { - match self.data { - CreateTableBuilderInitialData::Iterator(maybe_iter) => { - let data = maybe_iter?; - Ok(CreateTableRequest { - data: CreateTableData::Data(data), - ..self.request - }) - } - CreateTableBuilderInitialData::None => { - unreachable!("No data provided for CreateTableBuilder") - } - CreateTableBuilderInitialData::Stream(maybe_stream) => { - let data = maybe_stream?; - Ok(CreateTableRequest { - data: CreateTableData::StreamingData(data), - ..self.request - }) - } - } - } else { - let CreateTableBuilderInitialData::Iterator(maybe_iter) = self.data else { - return Err(Error::NotSupported { message: "Creating a table with embeddings is currently not support when the input is streaming".to_string() }); - }; - let data = maybe_iter?; - let data = Box::new(WithEmbeddings::new(data, self.embeddings)); - Ok(CreateTableRequest { - data: CreateTableData::Data(data), - ..self.request - }) - } - } -} - -// Builder methods that only apply when we do not have initial data -impl CreateTableBuilder { - fn new( - parent: Arc, - name: String, - schema: SchemaRef, - embedding_registry: Arc, - ) -> Self { - let table_definition = TableDefinition::new_from_schema(schema); - Self { - parent, - request: CreateTableRequest::new(name, CreateTableData::Empty(table_definition)), - data: CreateTableBuilderInitialData::None, - embeddings: Vec::default(), - embedding_registry, - } - } - - /// Execute the create table operation - pub async fn execute(self) -> Result
{ - let parent = self.parent.clone(); - let embedding_registry = self.embedding_registry.clone(); - let request = self.into_request()?; - Ok(Table::new_with_embedding_registry( - parent.create_table(request).await?, - parent, - embedding_registry, - )) - } - - fn into_request(self) -> Result { - if self.embeddings.is_empty() { - return Ok(self.request); - } - - let CreateTableData::Empty(table_def) = self.request.data else { - unreachable!("CreateTableBuilder should always have Empty data") - }; - - let schema = table_def.schema.clone(); - let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone()); - - let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::>()); - let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema); - let with_embeddings = WithEmbeddings::new(reader, self.embeddings); - let table_definition = with_embeddings.table_definition()?; - - Ok(CreateTableRequest { - data: CreateTableData::Empty(table_definition), - ..self.request - }) - } -} - -impl CreateTableBuilder { - /// Set the mode for creating the table - /// - /// This controls what happens if a table with the given name already exists - pub fn mode(mut self, mode: CreateTableMode) -> Self { - self.request.mode = mode; - self - } - - /// Apply the given write options when writing the initial data - pub fn write_options(mut self, write_options: WriteOptions) -> Self { - self.request.write_options = write_options; - self - } - - /// Set an option for the storage layer. - /// - /// Options already set on the connection will be inherited by the table, - /// but can be overridden here. 
- /// - /// See available options at - pub fn storage_option(mut self, key: impl Into, value: impl Into) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - merge_storage_options(store_params, [(key.into(), value.into())]); - self - } - - /// Set multiple options for the storage layer. - /// - /// Options already set on the connection will be inherited by the table, - /// but can be overridden here. - /// - /// See available options at - pub fn storage_options( - mut self, - pairs: impl IntoIterator, impl Into)>, - ) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - let updates = pairs - .into_iter() - .map(|(key, value)| (key.into(), value.into())); - merge_storage_options(store_params, updates); - self - } - - /// Add an embedding definition to the table. - /// - /// The `embedding_name` must match the name of an embedding function that - /// was previously registered with the connection's [`EmbeddingRegistry`]. - pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { - // Early verification of the embedding name - let embedding_func = self - .embedding_registry - .get(&definition.embedding_name) - .ok_or_else(|| Error::EmbeddingFunctionNotFound { - name: definition.embedding_name.clone(), - reason: "No embedding function found in the connection's embedding_registry" - .to_string(), - })?; - - self.embeddings.push((definition, embedding_func)); - Ok(self) - } - - /// Set whether to use V2 manifest paths for the table. (default: false) - /// - /// These paths provide more efficient opening of tables with many - /// versions on object stores. - /// - ///
Turning this on will make the dataset unreadable - /// for older versions of LanceDB (prior to 0.10.0).
- /// - /// To migrate an existing dataset, instead use the - /// [[NativeTable::migrate_manifest_paths_v2]]. - /// - /// This has no effect in LanceDB Cloud. - #[deprecated(since = "0.15.1", note = "Use `database_options` instead")] - pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert_with(Default::default) - .store_params - .get_or_insert_with(Default::default); - let value = if use_v2_manifest_paths { - "true".to_string() - } else { - "false".to_string() - }; - merge_storage_options( - store_params, - [(OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), value)], - ); - self - } - - /// Set the data storage version. - /// - /// The default is `LanceFileVersion::Stable`. - #[deprecated(since = "0.15.1", note = "Use `database_options` instead")] - pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert_with(Default::default) - .store_params - .get_or_insert_with(Default::default); - merge_storage_options( - store_params, - [( - OPT_NEW_TABLE_STORAGE_VERSION.to_string(), - data_storage_version.to_string(), - )], - ); - self - } - - /// Set the namespace for the table - pub fn namespace(mut self, namespace: Vec) -> Self { - self.request.namespace = namespace; - self - } - - /// Set a custom location for the table. - /// - /// If not set, the database will derive a location from its URI and the table name. - /// This is useful when integrating with namespace systems that manage table locations. - pub fn location(mut self, location: impl Into) -> Self { - self.request.location = Some(location.into()); - self - } - - /// Set a storage options provider for automatic credential refresh. 
- /// - /// This allows tables to automatically refresh cloud storage credentials - /// when they expire, enabling long-running operations on remote storage. - pub fn storage_options_provider(mut self, provider: Arc) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - set_storage_options_provider(store_params, provider); - self - } -} - #[derive(Clone, Debug)] pub struct OpenTableBuilder { parent: Arc, @@ -684,35 +351,17 @@ impl Connection { /// /// * `name` - The name of the table /// * `initial_data` - The initial data to write to the table - pub fn create_table( + pub fn create_table( &self, name: impl Into, initial_data: T, - ) -> CreateTableBuilder { - CreateTableBuilder::::new( + ) -> CreateTableBuilder { + let initial_data = Box::new(initial_data); + CreateTableBuilder::new( self.internal.clone(), + self.embedding_registry.clone(), name.into(), initial_data, - self.embedding_registry.clone(), - ) - } - - /// Create a new table from a stream of data - /// - /// # Parameters - /// - /// * `name` - The name of the table - /// * `initial_data` - The initial data to write to the table - pub fn create_table_streaming( - &self, - name: impl Into, - initial_data: T, - ) -> CreateTableBuilder { - CreateTableBuilder::::new_streaming( - self.internal.clone(), - name.into(), - initial_data, - self.embedding_registry.clone(), ) } @@ -726,13 +375,9 @@ impl Connection { &self, name: impl Into, schema: SchemaRef, - ) -> CreateTableBuilder { - CreateTableBuilder::::new( - self.internal.clone(), - name.into(), - schema, - self.embedding_registry.clone(), - ) + ) -> CreateTableBuilder { + let empty_batch = RecordBatch::new_empty(schema); + self.create_table(name, empty_batch) } /// Open an existing table in the database @@ -1349,20 +994,11 @@ mod test_utils { #[cfg(test)] mod tests { - use crate::database::listing::{ListingDatabaseOptions, 
NewTableConfig}; - use crate::query::QueryBase; - use crate::query::{ExecutableQuery, QueryExecutionOptions}; - use crate::test_utils::connection::new_test_connection; - use arrow::compute::concat_batches; - use arrow_array::RecordBatchReader; use arrow_schema::{DataType, Field, Schema}; - use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - use futures::{stream, TryStreamExt}; - use lance_core::error::{ArrowResult, DataFusionResult}; use lance_testing::datagen::{BatchGenerator, IncrementingInt32}; use tempfile::tempdir; - use crate::arrow::SimpleRecordBatchStream; + use crate::test_utils::connection::new_test_connection; use super::*; @@ -1478,139 +1114,6 @@ mod tests { assert_eq!(tables, vec!["table1".to_owned()]); } - fn make_data() -> Box { - let id = Box::new(IncrementingInt32::new().named("id".to_string())); - Box::new(BatchGenerator::new().col(id).batches(10, 2000)) - } - - #[tokio::test] - async fn test_create_table_v2() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri) - .database_options(&ListingDatabaseOptions { - new_table_config: NewTableConfig { - data_storage_version: Some(LanceFileVersion::Legacy), - ..Default::default() - }, - ..Default::default() - }) - .execute() - .await - .unwrap(); - - let tbl = db - .create_table("v1_test", make_data()) - .execute() - .await - .unwrap(); - - // In v1 the row group size will trump max_batch_length - let batches = tbl - .query() - .limit(20000) - .execute_with_options(QueryExecutionOptions { - max_batch_length: 50000, - ..Default::default() - }) - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - assert_eq!(batches.len(), 20); - - let db = connect(uri) - .database_options(&ListingDatabaseOptions { - new_table_config: NewTableConfig { - data_storage_version: Some(LanceFileVersion::Stable), - ..Default::default() - }, - ..Default::default() - }) - .execute() - .await - .unwrap(); - - let tbl = db - .create_table("v2_test", 
make_data()) - .execute() - .await - .unwrap(); - - // In v2 the page size is much bigger than 50k so we should get a single batch - let batches = tbl - .query() - .execute_with_options(QueryExecutionOptions { - max_batch_length: 50000, - ..Default::default() - }) - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - assert_eq!(batches.len(), 1); - } - - #[tokio::test] - async fn test_create_table_streaming() { - let tmp_dir = tempdir().unwrap(); - - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - - let batches = make_data().collect::>>().unwrap(); - - let schema = batches.first().unwrap().schema(); - let one_batch = concat_batches(&schema, batches.iter()).unwrap(); - - let ldb_stream = stream::iter(batches.clone().into_iter().map(Result::Ok)); - let ldb_stream: SendableRecordBatchStream = - Box::pin(SimpleRecordBatchStream::new(ldb_stream, schema.clone())); - - let tbl1 = db - .create_table_streaming("one", ldb_stream) - .execute() - .await - .unwrap(); - - let df_stream = stream::iter(batches.into_iter().map(DataFusionResult::Ok)); - let df_stream: datafusion_physical_plan::SendableRecordBatchStream = - Box::pin(RecordBatchStreamAdapter::new(schema.clone(), df_stream)); - - let tbl2 = db - .create_table_streaming("two", df_stream) - .execute() - .await - .unwrap(); - - let tbl1_data = tbl1 - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - let tbl1_data = concat_batches(&schema, tbl1_data.iter()).unwrap(); - assert_eq!(tbl1_data, one_batch); - - let tbl2_data = tbl2 - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - let tbl2_data = concat_batches(&schema, tbl2_data.iter()).unwrap(); - assert_eq!(tbl2_data, one_batch); - } - #[tokio::test] async fn drop_table() { let tc = new_test_connection().await.unwrap(); @@ -1640,41 +1143,6 @@ mod tests { assert_eq!(tables.len(), 0); } - #[tokio::test] - async fn 
test_create_table_already_exists() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); - db.create_empty_table("test", schema.clone()) - .execute() - .await - .unwrap(); - // TODO: None of the open table options are "inspectable" right now but once one is we - // should assert we are passing these options in correctly - db.create_empty_table("test", schema) - .mode(CreateTableMode::exist_ok(|mut req| { - req.index_cache_size = Some(16); - req - })) - .execute() - .await - .unwrap(); - let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)])); - assert!(db - .create_empty_table("test", other_schema.clone()) - .execute() - .await - .is_err()); - let overwritten = db - .create_empty_table("test", other_schema.clone()) - .mode(CreateTableMode::Overwrite) - .execute() - .await - .unwrap(); - assert_eq!(other_schema, overwritten.schema().await.unwrap()); - } - #[tokio::test] async fn test_clone_table() { let tmp_dir = tempdir().unwrap(); @@ -1685,7 +1153,8 @@ mod tests { let mut batch_gen = BatchGenerator::new() .col(Box::new(IncrementingInt32::new().named("id"))) .col(Box::new(IncrementingInt32::new().named("value"))); - let reader = batch_gen.batches(5, 100); + let reader: Box = + Box::new(batch_gen.batches(5, 100)); let source_table = db .create_table("source_table", reader) @@ -1720,128 +1189,4 @@ mod tests { let cloned_count = cloned_table.count_rows(None).await.unwrap(); assert_eq!(source_count, cloned_count); } - - #[tokio::test] - async fn test_create_empty_table_with_embeddings() { - use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction}; - use arrow_array::{ - Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray, - }; - use std::borrow::Cow; - - #[derive(Debug, Clone)] - struct MockEmbedding { - dim: usize, - } - - impl 
EmbeddingFunction for MockEmbedding { - fn name(&self) -> &str { - "test_embedding" - } - - fn source_type(&self) -> Result> { - Ok(Cow::Owned(DataType::Utf8)) - } - - fn dest_type(&self) -> Result> { - Ok(Cow::Owned(DataType::new_fixed_size_list( - DataType::Float32, - self.dim as i32, - true, - ))) - } - - fn compute_source_embeddings(&self, source: Arc) -> Result> { - let len = source.len(); - let values = vec![1.0f32; len * self.dim]; - let values = Arc::new(Float32Array::from(values)); - let field = Arc::new(Field::new("item", DataType::Float32, true)); - Ok(Arc::new(FixedSizeListArray::new( - field, - self.dim as i32, - values, - None, - ))) - } - - fn compute_query_embeddings(&self, _input: Arc) -> Result> { - unimplemented!() - } - } - - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - - let embed_func = Arc::new(MockEmbedding { dim: 128 }); - db.embedding_registry() - .register("test_embedding", embed_func.clone()) - .unwrap(); - - let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); - let ed = EmbeddingDefinition { - source_column: "name".to_owned(), - dest_column: Some("name_embedding".to_owned()), - embedding_name: "test_embedding".to_owned(), - }; - - let table = db - .create_empty_table("test", schema) - .mode(CreateTableMode::Overwrite) - .add_embedding(ed) - .unwrap() - .execute() - .await - .unwrap(); - - let table_schema = table.schema().await.unwrap(); - assert!(table_schema.column_with_name("name").is_some()); - assert!(table_schema.column_with_name("name_embedding").is_some()); - - let embedding_field = table_schema.field_with_name("name_embedding").unwrap(); - assert_eq!( - embedding_field.data_type(), - &DataType::new_fixed_size_list(DataType::Float32, 128, true) - ); - - let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); - let input_batch = RecordBatch::try_new( - input_schema.clone(), - 
vec![Arc::new(StringArray::from(vec![ - Some("Alice"), - Some("Bob"), - Some("Charlie"), - ]))], - ) - .unwrap(); - - let input_reader = Box::new(RecordBatchIterator::new( - vec![Ok(input_batch)].into_iter(), - input_schema, - )); - - table.add(input_reader).execute().await.unwrap(); - - let results = table - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - assert_eq!(results.len(), 1); - let batch = &results[0]; - assert_eq!(batch.num_rows(), 3); - assert!(batch.column_by_name("name_embedding").is_some()); - - let embedding_col = batch - .column_by_name("name_embedding") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(embedding_col.len(), 3); - } } diff --git a/rust/lancedb/src/connection/create_table.rs b/rust/lancedb/src/connection/create_table.rs new file mode 100644 index 000000000..8eb7d2207 --- /dev/null +++ b/rust/lancedb/src/connection/create_table.rs @@ -0,0 +1,612 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use lance_io::object_store::StorageOptionsProvider; + +use crate::{ + connection::{merge_storage_options, set_storage_options_provider}, + data::scannable::{Scannable, WithEmbeddingsScannable}, + database::{CreateTableMode, CreateTableRequest, Database}, + embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry}, + table::WriteOptions, + Error, Result, Table, +}; + +pub struct CreateTableBuilder { + parent: Arc, + embeddings: Vec<(EmbeddingDefinition, Arc)>, + embedding_registry: Arc, + request: CreateTableRequest, +} + +impl CreateTableBuilder { + pub(super) fn new( + parent: Arc, + embedding_registry: Arc, + name: String, + data: Box, + ) -> Self { + Self { + parent, + embeddings: Vec::new(), + embedding_registry, + request: CreateTableRequest::new(name, data), + } + } + + /// Set the mode for creating the table + /// + /// This controls what happens if a table with the given name already 
exists + pub fn mode(mut self, mode: CreateTableMode) -> Self { + self.request.mode = mode; + self + } + + /// Apply the given write options when writing the initial data + pub fn write_options(mut self, write_options: WriteOptions) -> Self { + self.request.write_options = write_options; + self + } + + /// Set an option for the storage layer. + /// + /// Options already set on the connection will be inherited by the table, + /// but can be overridden here. + /// + /// See available options at + pub fn storage_option(mut self, key: impl Into, value: impl Into) -> Self { + let store_params = self + .request + .write_options + .lance_write_params + .get_or_insert(Default::default()) + .store_params + .get_or_insert(Default::default()); + merge_storage_options(store_params, [(key.into(), value.into())]); + self + } + + /// Set multiple options for the storage layer. + /// + /// Options already set on the connection will be inherited by the table, + /// but can be overridden here. + /// + /// See available options at + pub fn storage_options( + mut self, + pairs: impl IntoIterator, impl Into)>, + ) -> Self { + let store_params = self + .request + .write_options + .lance_write_params + .get_or_insert(Default::default()) + .store_params + .get_or_insert(Default::default()); + let updates = pairs + .into_iter() + .map(|(key, value)| (key.into(), value.into())); + merge_storage_options(store_params, updates); + self + } + + /// Add an embedding definition to the table. + /// + /// The `embedding_name` must match the name of an embedding function that + /// was previously registered with the connection's [`EmbeddingRegistry`]. 
+ pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { + // Early verification of the embedding name + let embedding_func = self + .embedding_registry + .get(&definition.embedding_name) + .ok_or_else(|| Error::EmbeddingFunctionNotFound { + name: definition.embedding_name.clone(), + reason: "No embedding function found in the connection's embedding_registry" + .to_string(), + })?; + + self.embeddings.push((definition, embedding_func)); + Ok(self) + } + + /// Set the namespace for the table + pub fn namespace(mut self, namespace: Vec) -> Self { + self.request.namespace = namespace; + self + } + + /// Set a custom location for the table. + /// + /// If not set, the database will derive a location from its URI and the table name. + /// This is useful when integrating with namespace systems that manage table locations. + pub fn location(mut self, location: impl Into) -> Self { + self.request.location = Some(location.into()); + self + } + + /// Set a storage options provider for automatic credential refresh. + /// + /// This allows tables to automatically refresh cloud storage credentials + /// when they expire, enabling long-running operations on remote storage. + pub fn storage_options_provider(mut self, provider: Arc) -> Self { + let store_params = self + .request + .write_options + .lance_write_params + .get_or_insert(Default::default()) + .store_params + .get_or_insert(Default::default()); + set_storage_options_provider(store_params, provider); + self + } + + /// Execute the create table operation + pub async fn execute(mut self) -> Result
{ + let embedding_registry = self.embedding_registry.clone(); + let parent = self.parent.clone(); + + // If embeddings were configured via add_embedding(), wrap the data + if !self.embeddings.is_empty() { + let wrapped_data: Box = Box::new(WithEmbeddingsScannable::try_new( + self.request.data, + self.embeddings, + )?); + self.request.data = wrapped_data; + } + + Ok(Table::new_with_embedding_registry( + parent.create_table(self.request).await?, + parent, + embedding_registry, + )) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::{ + record_batch, Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, + }; + use arrow_schema::{ArrowError, DataType, Field, Schema}; + use futures::TryStreamExt; + use lance_file::version::LanceFileVersion; + use tempfile::tempdir; + + use crate::{ + arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}, + connect, + database::listing::{ListingDatabaseOptions, NewTableConfig}, + embeddings::{EmbeddingDefinition, EmbeddingFunction, MemoryRegistry}, + query::{ExecutableQuery, QueryBase, Select}, + test_utils::embeddings::MockEmbed, + }; + use std::borrow::Cow; + + use super::*; + + #[tokio::test] + async fn create_empty_table() { + let db = connect("memory://").execute().await.unwrap(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Float64, false), + ])); + db.create_empty_table("name", schema.clone()) + .execute() + .await + .unwrap(); + let table = db.open_table("name").execute().await.unwrap(); + assert_eq!(table.schema().await.unwrap(), schema); + assert_eq!(table.count_rows(None).await.unwrap(), 0); + } + + async fn test_create_table_with_data(data: T) + where + T: Scannable + 'static, + { + let db = connect("memory://").execute().await.unwrap(); + let schema = data.schema(); + db.create_table("data_table", data).execute().await.unwrap(); + let table = db.open_table("data_table").execute().await.unwrap(); + 
assert_eq!(table.count_rows(None).await.unwrap(), 3); + assert_eq!(table.schema().await.unwrap(), schema); + } + + #[tokio::test] + async fn create_table_with_batch() { + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + test_create_table_with_data(batch).await; + } + + #[tokio::test] + async fn test_create_table_with_vec_batch() { + let data = vec![ + record_batch!(("id", Int64, [1, 2])).unwrap(), + record_batch!(("id", Int64, [3])).unwrap(), + ]; + test_create_table_with_data(data).await; + } + + #[tokio::test] + async fn test_create_table_with_record_batch_reader() { + let data = vec![ + record_batch!(("id", Int64, [1, 2])).unwrap(), + record_batch!(("id", Int64, [3])).unwrap(), + ]; + let schema = data[0].schema(); + let reader: Box = Box::new( + RecordBatchIterator::new(data.into_iter().map(Ok), schema.clone()), + ); + test_create_table_with_data(reader).await; + } + + #[tokio::test] + async fn test_create_table_with_stream() { + let data = vec![ + record_batch!(("id", Int64, [1, 2])).unwrap(), + record_batch!(("id", Int64, [3])).unwrap(), + ]; + let schema = data[0].schema(); + let inner = futures::stream::iter(data.into_iter().map(Ok)); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema, + stream: inner, + }); + test_create_table_with_data(stream).await; + } + + #[derive(Debug)] + struct MyError; + + impl std::fmt::Display for MyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MyError occurred") + } + } + + impl std::error::Error for MyError {} + + #[tokio::test] + async fn test_create_preserves_reader_error() { + let first_batch = record_batch!(("id", Int64, [1, 2])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(ArrowError::ExternalError(Box::new(MyError))), + ]; + let reader: Box = Box::new( + RecordBatchIterator::new(iterator.into_iter(), schema.clone()), + ); + + let db = 
connect("memory://").execute().await.unwrap(); + let result = db.create_table("failing_table", reader).execute().await; + + assert!(result.is_err()); + // TODO: when we upgrade to Lance 2.0.0, this should pass + // assert!(matches!(result, Err(Error::External { source}) + // if source.downcast_ref::().is_some() + // )); + } + + #[tokio::test] + async fn test_create_preserves_stream_error() { + let first_batch = record_batch!(("id", Int64, [1, 2])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(Error::External { + source: Box::new(MyError), + }), + ]; + let stream = futures::stream::iter(iterator); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream, + }); + + let db = connect("memory://").execute().await.unwrap(); + let result = db + .create_table("failing_stream_table", stream) + .execute() + .await; + + assert!(result.is_err()); + // TODO: when we upgrade to Lance 2.0.0, this should pass + // assert!(matches!(result, Err(Error::External { source}) + // if source.downcast_ref::().is_some() + // )); + } + + #[tokio::test] + #[allow(deprecated)] + async fn test_create_table_with_storage_options() { + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + let db = connect("memory://").execute().await.unwrap(); + + let table = db + .create_table("options_table", batch) + .storage_option("timeout", "30s") + .storage_options([("retry_count", "3")]) + .execute() + .await + .unwrap(); + + let final_options = table.storage_options().await.unwrap(); + assert_eq!(final_options.get("timeout"), Some(&"30s".to_string())); + assert_eq!(final_options.get("retry_count"), Some(&"3".to_string())); + } + + #[tokio::test] + async fn test_create_table_unregistered_embedding() { + let db = connect("memory://").execute().await.unwrap(); + let batch = record_batch!(("text", Utf8, ["hello", "world"])).unwrap(); + + // Try to add an embedding that doesn't exist in the 
registry + let result = db + .create_table("embed_table", batch) + .add_embedding(EmbeddingDefinition::new( + "text", + "nonexistent_embedding_function", + None::<&str>, + )); + + match result { + Err(Error::EmbeddingFunctionNotFound { name, .. }) => { + assert_eq!(name, "nonexistent_embedding_function"); + } + Err(other) => panic!("Expected EmbeddingFunctionNotFound error, got: {:?}", other), + Ok(_) => panic!("Expected error, but got Ok"), + } + } + + #[tokio::test] + async fn test_create_table_already_exists() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let db = connect(uri).execute().await.unwrap(); + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); + db.create_empty_table("test", schema.clone()) + .execute() + .await + .unwrap(); + db.create_empty_table("test", schema) + .mode(CreateTableMode::exist_ok(|mut req| { + req.index_cache_size = Some(16); + req + })) + .execute() + .await + .unwrap(); + let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)])); + assert!(db + .create_empty_table("test", other_schema.clone()) + .execute() + .await + .is_err()); // TODO: assert what this error is + let overwritten = db + .create_empty_table("test", other_schema.clone()) + .mode(CreateTableMode::Overwrite) + .execute() + .await + .unwrap(); + assert_eq!(other_schema, overwritten.schema().await.unwrap()); + } + + #[tokio::test] + #[rstest::rstest] + #[case(LanceFileVersion::Legacy)] + #[case(LanceFileVersion::Stable)] + async fn test_create_table_with_storage_version( + #[case] data_storage_version: LanceFileVersion, + ) { + let db = connect("memory://") + .database_options(&ListingDatabaseOptions { + new_table_config: NewTableConfig { + data_storage_version: Some(data_storage_version), + ..Default::default() + }, + ..Default::default() + }) + .execute() + .await + .unwrap(); + + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + let table = db + 
.create_table("legacy_table", batch) + .execute() + .await + .unwrap(); + + let native_table = table.as_native().unwrap(); + let storage_format = native_table + .manifest() + .await + .unwrap() + .data_storage_format + .lance_file_version() + .unwrap(); + // Compare resolved versions since Stable/Next are aliases that resolve at storage time + assert_eq!(storage_format.resolve(), data_storage_version.resolve()); + } + + #[tokio::test] + async fn test_create_table_with_embedding() { + // Register the mock embedding function + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + // Connect with the custom registry + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + // Create data without the embedding column + let batch = record_batch!(("text", Utf8, ["hello", "world", "test"])).unwrap(); + + // Create table with add_embedding - embeddings should be computed automatically + let table = conn + .create_table("embed_test", batch) + .add_embedding(EmbeddingDefinition::new( + "text", + "mock", + Some("text_embedding"), + )) + .unwrap() + .execute() + .await + .unwrap(); + + // Verify row count + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + // Verify the schema includes the embedding column + let result_schema = table.schema().await.unwrap(); + assert_eq!(result_schema.fields().len(), 2); + assert_eq!(result_schema.field(0).name(), "text"); + assert_eq!(result_schema.field(1).name(), "text_embedding"); + + // Verify the embedding column has the correct type + assert!(matches!( + result_schema.field(1).data_type(), + DataType::FixedSizeList(_, 4) + )); + + // Query to verify the embeddings were computed + let results: Vec = table + .query() + .select(Select::columns(&["text", "text_embedding"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize 
= results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + + // Check that all rows have embedding values (not null) + for batch in &results { + let embedding_col = batch.column(1); + assert_eq!(embedding_col.null_count(), 0); + assert_eq!(embedding_col.len(), batch.num_rows()); + } + + // Verify the schema metadata contains the column definitions + assert!( + result_schema + .metadata + .contains_key("lancedb::column_definitions"), + "Schema metadata should contain column definitions" + ); + } + + #[tokio::test] + async fn test_create_empty_table_with_embeddings() { + #[derive(Debug, Clone)] + struct MockEmbedding { + dim: usize, + } + + impl EmbeddingFunction for MockEmbedding { + fn name(&self) -> &str { + "test_embedding" + } + + fn source_type(&self) -> Result> { + Ok(Cow::Owned(DataType::Utf8)) + } + + fn dest_type(&self) -> Result> { + Ok(Cow::Owned(DataType::new_fixed_size_list( + DataType::Float32, + self.dim as i32, + true, + ))) + } + + fn compute_source_embeddings(&self, source: Arc) -> Result> { + let len = source.len(); + let values = vec![1.0f32; len * self.dim]; + let values = Arc::new(Float32Array::from(values)); + let field = Arc::new(Field::new("item", DataType::Float32, true)); + Ok(Arc::new(FixedSizeListArray::new( + field, + self.dim as i32, + values, + None, + ))) + } + + fn compute_query_embeddings(&self, _input: Arc) -> Result> { + unimplemented!() + } + } + + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let db = connect(uri).execute().await.unwrap(); + + let embed_func = Arc::new(MockEmbedding { dim: 128 }); + db.embedding_registry() + .register("test_embedding", embed_func.clone()) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + let ed = EmbeddingDefinition { + source_column: "name".to_owned(), + dest_column: Some("name_embedding".to_owned()), + embedding_name: "test_embedding".to_owned(), + }; + + let table = db + 
.create_empty_table("test", schema) + .mode(CreateTableMode::Overwrite) + .add_embedding(ed) + .unwrap() + .execute() + .await + .unwrap(); + + let table_schema = table.schema().await.unwrap(); + assert!(table_schema.column_with_name("name").is_some()); + assert!(table_schema.column_with_name("name_embedding").is_some()); + + let embedding_field = table_schema.field_with_name("name_embedding").unwrap(); + assert_eq!( + embedding_field.data_type(), + &DataType::new_fixed_size_list(DataType::Float32, 128, true) + ); + + let input_batch = record_batch!(("name", Utf8, ["Alice", "Bob", "Charlie"])).unwrap(); + table.add(input_batch).execute().await.unwrap(); + + let results = table + .query() + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 3); + assert!(batch.column_by_name("name_embedding").is_some()); + + let embedding_col = batch + .column_by_name("name_embedding") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(embedding_col.len(), 3); + } +} diff --git a/rust/lancedb/src/data.rs b/rust/lancedb/src/data.rs index d988884a1..16840038f 100644 --- a/rust/lancedb/src/data.rs +++ b/rust/lancedb/src/data.rs @@ -5,3 +5,4 @@ pub mod inspect; pub mod sanitize; +pub mod scannable; diff --git a/rust/lancedb/src/data/scannable.rs b/rust/lancedb/src/data/scannable.rs new file mode 100644 index 000000000..350742bd7 --- /dev/null +++ b/rust/lancedb/src/data/scannable.rs @@ -0,0 +1,580 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Data source abstraction for LanceDB. +//! +//! This module provides a [`Scannable`] trait that allows input data sources to express +//! capabilities (row count, rescannability) so the insert pipeline can make +//! better decisions about write parallelism and retry strategies. 
+ +use std::sync::Arc; + +use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader}; +use arrow_schema::{ArrowError, SchemaRef}; +use async_trait::async_trait; +use futures::stream::once; +use futures::StreamExt; +use lance_datafusion::utils::StreamingWriteSource; + +use crate::arrow::{ + SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream, +}; +use crate::embeddings::{ + compute_embeddings_for_batch, compute_output_schema, EmbeddingDefinition, EmbeddingFunction, + EmbeddingRegistry, +}; +use crate::table::{ColumnDefinition, ColumnKind, TableDefinition}; +use crate::{Error, Result}; + +pub trait Scannable: Send { + /// Returns the schema of the data. + fn schema(&self) -> SchemaRef; + + /// Read data as a stream of record batches. + /// + /// For rescannable sources (in-memory data like RecordBatch, Vec), + /// this can be called multiple times and returns cloned data each time. + /// + /// For non-rescannable sources (streams, readers), this can only be called once. + /// Calling it a second time returns a stream whose first item is an error. + fn scan_as_stream(&mut self) -> SendableRecordBatchStream; + + /// Optional hint about the number of rows. + /// + /// When available, this allows the pipeline to estimate total data size + /// and choose appropriate partitioning. + fn num_rows(&self) -> Option { + None + } + + /// Whether the source can be re-read from the beginning. + /// + /// `true` for in-memory data (Tables, DataFrames) and disk-based sources (Datasets). + /// `false` for streaming sources (DuckDB results, network streams). + /// + /// When true, the pipeline can retry failed writes by rescanning. 
+ fn rescannable(&self) -> bool { + false + } +} + +impl std::fmt::Debug for dyn Scannable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Scannable") + .field("schema", &self.schema()) + .field("num_rows", &self.num_rows()) + .field("rescannable", &self.rescannable()) + .finish() + } +} + +impl Scannable for RecordBatch { + fn schema(&self) -> SchemaRef { + Self::schema(self) + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + let batch = self.clone(); + let schema = batch.schema(); + Box::pin(SimpleRecordBatchStream { + schema, + stream: once(async move { Ok(batch) }), + }) + } + + fn num_rows(&self) -> Option { + Some(Self::num_rows(self)) + } + + fn rescannable(&self) -> bool { + true + } +} + +impl Scannable for Vec { + fn schema(&self) -> SchemaRef { + if self.is_empty() { + Arc::new(arrow_schema::Schema::empty()) + } else { + self[0].schema() + } + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + if self.is_empty() { + let schema = Scannable::schema(self); + return Box::pin(SimpleRecordBatchStream { + schema, + stream: once(async { + Err(Error::InvalidInput { + message: "Cannot scan an empty Vec".to_string(), + }) + }), + }); + } + let schema = Scannable::schema(self); + let batches = self.clone(); + let stream = futures::stream::iter(batches.into_iter().map(Ok)); + Box::pin(SimpleRecordBatchStream { schema, stream }) + } + + fn num_rows(&self) -> Option { + Some(self.iter().map(|b| b.num_rows()).sum()) + } + + fn rescannable(&self) -> bool { + true + } +} + +impl Scannable for Box { + fn schema(&self) -> SchemaRef { + RecordBatchReader::schema(self.as_ref()) + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + let schema = Scannable::schema(self); + + // Swap self with a reader that errors on iteration, so a second call + // produces a clear error instead of silently returning empty data. 
+ let err_reader: Box = Box::new(RecordBatchIterator::new( + vec![Err(ArrowError::InvalidArgumentError( + "Reader has already been consumed".into(), + ))], + schema.clone(), + )); + let reader = std::mem::replace(self, err_reader); + + // Bridge the blocking RecordBatchReader to an async stream via a channel. + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + tokio::task::spawn_blocking(move || { + for batch_result in reader { + let result = batch_result.map_err(Into::into); + if tx.blocking_send(result).is_err() { + break; + } + } + }); + + let stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|batch| (batch, rx)) + }) + .fuse(); + + Box::pin(SimpleRecordBatchStream { schema, stream }) + } +} + +impl Scannable for SendableRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.as_ref().schema() + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + let schema = Scannable::schema(self); + + // Swap self with an error stream so a second call produces a clear error. + let error_stream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream: once(async { + Err(Error::InvalidInput { + message: "Stream has already been consumed".to_string(), + }) + }), + }); + std::mem::replace(self, error_stream) + } +} + +#[async_trait] +impl StreamingWriteSource for Box { + fn arrow_schema(&self) -> SchemaRef { + self.schema() + } + + fn into_stream(mut self) -> datafusion_physical_plan::SendableRecordBatchStream { + self.scan_as_stream().into_df_stream() + } +} + +/// A scannable that applies embeddings to the stream. +pub struct WithEmbeddingsScannable { + inner: Box, + embeddings: Vec<(EmbeddingDefinition, Arc)>, + output_schema: SchemaRef, +} + +impl WithEmbeddingsScannable { + /// Create a new WithEmbeddingsScannable. + /// + /// The embeddings are applied to the inner scannable's data as new columns. 
+ pub fn try_new( + inner: Box, + embeddings: Vec<(EmbeddingDefinition, Arc)>, + ) -> Result { + let output_schema = compute_output_schema(&inner.schema(), &embeddings)?; + + // Build column definitions: Physical for base columns, Embedding for new ones + let base_col_count = inner.schema().fields().len(); + let column_definitions: Vec = (0..base_col_count) + .map(|_| ColumnDefinition { + kind: ColumnKind::Physical, + }) + .chain(embeddings.iter().map(|(ed, _)| ColumnDefinition { + kind: ColumnKind::Embedding(ed.clone()), + })) + .collect(); + + let table_definition = TableDefinition::new(output_schema, column_definitions); + let output_schema = table_definition.into_rich_schema(); + + Ok(Self { + inner, + embeddings, + output_schema, + }) + } +} + +impl Scannable for WithEmbeddingsScannable { + fn schema(&self) -> SchemaRef { + self.output_schema.clone() + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + let inner_stream = self.inner.scan_as_stream(); + let embeddings = self.embeddings.clone(); + let output_schema = self.output_schema.clone(); + + let mapped_stream = inner_stream.then(move |batch_result| { + let embeddings = embeddings.clone(); + async move { + let batch = batch_result?; + let result = tokio::task::spawn_blocking(move || { + compute_embeddings_for_batch(batch, &embeddings) + }) + .await + .map_err(|e| Error::Runtime { + message: format!("Task panicked during embedding computation: {}", e), + })??; + Ok(result) + } + }); + + Box::pin(SimpleRecordBatchStream { + schema: output_schema, + stream: mapped_stream, + }) + } + + fn num_rows(&self) -> Option { + self.inner.num_rows() + } + + fn rescannable(&self) -> bool { + self.inner.rescannable() + } +} + +pub fn scannable_with_embeddings( + inner: Box, + table_definition: &TableDefinition, + registry: Option<&Arc>, +) -> Result> { + if let Some(registry) = registry { + let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len()); + for cd in 
table_definition.column_definitions.iter() { + if let ColumnKind::Embedding(embedding_def) = &cd.kind { + match registry.get(&embedding_def.embedding_name) { + Some(func) => { + embeddings.push((embedding_def.clone(), func)); + } + None => { + return Err(Error::EmbeddingFunctionNotFound { + name: embedding_def.embedding_name.clone(), + reason: format!( + "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.", + embedding_def.embedding_name + ), + }); + } + } + } + } + + if !embeddings.is_empty() { + return Ok(Box::new(WithEmbeddingsScannable::try_new( + inner, embeddings, + )?)); + } + } + + Ok(inner) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::record_batch; + use futures::TryStreamExt; + + #[tokio::test] + async fn test_record_batch_rescannable() { + let mut batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap(); + + let stream1 = batch.scan_as_stream(); + let batches1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(batches1.len(), 1); + assert_eq!(batches1[0], batch); + + assert!(batch.rescannable()); + let stream2 = batch.scan_as_stream(); + let batches2: Vec = stream2.try_collect().await.unwrap(); + assert_eq!(batches2.len(), 1); + assert_eq!(batches2[0], batch); + } + + #[tokio::test] + async fn test_vec_batch_rescannable() { + let mut batches = vec![ + record_batch!(("id", Int64, [0, 1])).unwrap(), + record_batch!(("id", Int64, [2, 3, 4])).unwrap(), + ]; + + let stream1 = batches.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 2); + assert_eq!(result1[0], batches[0]); + assert_eq!(result1[1], batches[1]); + + assert!(batches.rescannable()); + let stream2 = batches.scan_as_stream(); + let result2: Vec = stream2.try_collect().await.unwrap(); + assert_eq!(result2.len(), 2); + assert_eq!(result2[0], batches[0]); + assert_eq!(result2[1], batches[1]); + } + + #[tokio::test] + async fn 
test_vec_batch_empty_errors() { + let mut empty: Vec = vec![]; + let mut stream = empty.scan_as_stream(); + let result = stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + } + + #[tokio::test] + async fn test_reader_not_rescannable() { + let batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap(); + let schema = batch.schema(); + let mut reader: Box = Box::new( + RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), + ); + + let stream1 = reader.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 1); + assert_eq!(result1[0], batch); + + assert!(!reader.rescannable()); + // Second call returns a stream whose first item is an error + let mut stream2 = reader.scan_as_stream(); + let result2 = stream2.next().await; + assert!(result2.is_some()); + assert!(result2.unwrap().is_err()); + } + + #[tokio::test] + async fn test_stream_not_rescannable() { + let batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap(); + let schema = batch.schema(); + let inner_stream = futures::stream::iter(vec![Ok(batch.clone())]); + let mut stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream: inner_stream, + }); + + let stream1 = stream.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 1); + assert_eq!(result1[0], batch); + + assert!(!stream.rescannable()); + // Second call returns a stream whose first item is an error + let mut stream2 = stream.scan_as_stream(); + let result2 = stream2.next().await; + assert!(result2.is_some()); + assert!(result2.unwrap().is_err()); + } + + mod embedding_tests { + use super::*; + use crate::embeddings::MemoryRegistry; + use crate::table::{ColumnDefinition, ColumnKind}; + use crate::test_utils::embeddings::MockEmbed; + use arrow_array::Array as _; + use arrow_array::{ArrayRef, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + + 
#[tokio::test] + async fn test_with_embeddings_scannable() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = StringArray::from(vec!["hello", "world", "test"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + + let mut scannable = WithEmbeddingsScannable::try_new( + Box::new(batch.clone()), + vec![(embedding_def, mock_embedding)], + ) + .unwrap(); + + // Check that schema has the embedding column + let output_schema = scannable.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "text"); + assert_eq!(output_schema.field(1).name(), "text_embedding"); + + // Check num_rows and rescannable are preserved + assert_eq!(scannable.num_rows(), Some(3)); + assert!(scannable.rescannable()); + + // Read the data + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + + let result_batch = &results[0]; + assert_eq!(result_batch.num_rows(), 3); + assert_eq!(result_batch.num_columns(), 2); + + // Verify the embedding column is present and has the right shape + let embedding_col = result_batch.column(1); + assert_eq!(embedding_col.len(), 3); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_no_embeddings() { + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + + // Create a table definition with no embedding columns + let table_def = TableDefinition::new_from_schema(batch.schema()); + + // Even with a registry, if there are no embedding columns, it's a passthrough + let registry: Arc = Arc::new(MemoryRegistry::new()); + let mut scannable = + scannable_with_embeddings(Box::new(batch.clone()), &table_def, Some(®istry)) + .unwrap(); + + // Check that data passes 
through unchanged + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], batch); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_with_embeddings() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = StringArray::from(vec!["hello", "world"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + // Create a table definition with an embedding column + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + let embedding_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])); + let table_def = TableDefinition::new( + embedding_schema, + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def.clone()), + }, + ], + ); + + // Register the mock embedding function + let registry: Arc = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let mut scannable = + scannable_with_embeddings(Box::new(batch), &table_def, Some(®istry)).unwrap(); + + // Read and verify the data has embeddings + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + + let result_batch = &results[0]; + assert_eq!(result_batch.num_columns(), 2); + assert_eq!(result_batch.schema().field(1).name(), "text_embedding"); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_missing_function() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = 
StringArray::from(vec!["hello"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + // Create a table definition with an embedding column + let embedding_def = + EmbeddingDefinition::new("text", "nonexistent", Some("text_embedding")); + let embedding_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])); + let table_def = TableDefinition::new( + embedding_schema, + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + + // Registry has no embedding functions registered + let registry: Arc = Arc::new(MemoryRegistry::new()); + + let result = scannable_with_embeddings(Box::new(batch), &table_def, Some(®istry)); + + // Should fail because the embedding function is not found + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!( + matches!(err, Error::EmbeddingFunctionNotFound { .. 
}), + "Expected EmbeddingFunctionNotFound" + ); + } + } +} diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs index 36947727c..a4a221193 100644 --- a/rust/lancedb/src/database.rs +++ b/rust/lancedb/src/database.rs @@ -18,12 +18,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use arrow_array::RecordBatchReader; -use async_trait::async_trait; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use futures::stream; use lance::dataset::ReadParams; -use lance_datafusion::utils::StreamingWriteSource; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, DescribeNamespaceResponse, DropNamespaceRequest, DropNamespaceResponse, ListNamespacesRequest, @@ -31,9 +26,9 @@ use lance_namespace::models::{ }; use lance_namespace::LanceNamespace; -use crate::arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt}; +use crate::data::scannable::Scannable; use crate::error::Result; -use crate::table::{BaseTable, TableDefinition, WriteOptions}; +use crate::table::{BaseTable, WriteOptions}; pub mod listing; pub mod namespace; @@ -115,51 +110,14 @@ impl Default for CreateTableMode { } } -/// The data to start a table or a schema to create an empty table -pub enum CreateTableData { - /// Creates a table using an iterator of data, the schema will be obtained from the data - Data(Box), - /// Creates a table using a stream of data, the schema will be obtained from the data - StreamingData(SendableRecordBatchStream), - /// Creates an empty table, the definition / schema must be provided separately - Empty(TableDefinition), -} - -impl CreateTableData { - pub fn schema(&self) -> Arc { - match self { - Self::Data(reader) => reader.schema(), - Self::StreamingData(stream) => stream.schema(), - Self::Empty(definition) => definition.schema.clone(), - } - } -} - -#[async_trait] -impl StreamingWriteSource for CreateTableData { - fn arrow_schema(&self) -> Arc { - 
self.schema() - } - fn into_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream { - match self { - Self::Data(reader) => reader.into_stream(), - Self::StreamingData(stream) => stream.into_df_stream(), - Self::Empty(table_definition) => { - let schema = table_definition.schema.clone(); - Box::pin(RecordBatchStreamAdapter::new(schema, stream::empty())) - } - } - } -} - /// A request to create a table pub struct CreateTableRequest { /// The name of the new table pub name: String, /// The namespace to create the table in. Empty list represents root namespace. pub namespace: Vec, - /// Initial data to write to the table, can be None to create an empty table - pub data: CreateTableData, + /// Initial data to write to the table, can be empty. + pub data: Box, /// The mode to use when creating the table pub mode: CreateTableMode, /// Options to use when writing data (only used if `data` is not None) @@ -173,7 +131,7 @@ pub struct CreateTableRequest { } impl CreateTableRequest { - pub fn new(name: String, data: CreateTableData) -> Self { + pub fn new(name: String, data: Box) -> Self { Self { name, namespace: vec![], diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs index b2b08b607..aea5ce24c 100644 --- a/rust/lancedb/src/database/listing.rs +++ b/rust/lancedb/src/database/listing.rs @@ -922,7 +922,7 @@ impl Database for ListingDatabase { .with_read_params(read_params.clone()) .load() .await - .map_err(|e| Error::Lance { source: e })?; + .map_err(|e| -> Error { e.into() })?; let version_ref = match (request.source_version, request.source_tag) { (Some(v), None) => Ok(Ref::Version(None, Some(v))), @@ -937,7 +937,7 @@ impl Database for ListingDatabase { source_dataset .shallow_clone(&target_uri, version_ref, Some(storage_params)) .await - .map_err(|e| Error::Lance { source: e })?; + .map_err(|e| -> Error { e.into() })?; let cloned_table = NativeTable::open_with_params( &target_uri, @@ -1098,8 +1098,10 @@ impl Database for 
ListingDatabase { mod tests { use super::*; use crate::connection::ConnectRequest; - use crate::database::{CreateTableData, CreateTableMode, CreateTableRequest, WriteOptions}; - use crate::table::{Table, TableDefinition}; + use crate::data::scannable::Scannable; + use crate::database::{CreateTableMode, CreateTableRequest}; + use crate::table::WriteOptions; + use crate::Table; use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use std::path::PathBuf; @@ -1139,7 +1141,7 @@ mod tests { .create_table(CreateTableRequest { name: "source_table".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema.clone())), + data: Box::new(RecordBatch::new_empty(schema.clone())) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1196,16 +1198,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "source_with_data".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1264,7 +1261,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1300,7 +1297,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1340,7 +1337,7 @@ mod tests { 
db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1380,7 +1377,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1435,7 +1432,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1484,16 +1481,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "versioned_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1517,14 +1509,7 @@ mod tests { let db = Arc::new(db); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch2)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch2).execute().await.unwrap(); // Verify source table now has 4 rows assert_eq!(source_table.count_rows(None).await.unwrap(), 4); @@ -1570,16 +1555,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = 
db .create_table(CreateTableRequest { name: "tagged_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1607,14 +1587,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch2)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch2).execute().await.unwrap(); // Source table should have 4 rows assert_eq!(source_table.count_rows(None).await.unwrap(), 4); @@ -1657,16 +1630,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "independent_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1706,14 +1674,7 @@ mod tests { let db = Arc::new(db); let cloned_table_obj = Table::new(cloned_table.clone(), db.clone()); - cloned_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch_clone)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + cloned_table_obj.add(batch_clone).execute().await.unwrap(); // Add different data to the source table let batch_source = RecordBatch::try_new( @@ -1726,14 +1687,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch_source)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch_source).execute().await.unwrap(); // Verify they have evolved independently assert_eq!(source_table.count_rows(None).await.unwrap(), 4); // 2 + 2 @@ -1751,16 +1705,11 @@ mod tests { 
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))]) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "latest_version_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1779,14 +1728,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch).execute().await.unwrap(); } // Source should have 8 rows total (2 + 2 + 2 + 2) @@ -1849,16 +1791,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let table = db .create_table(CreateTableRequest { name: "test_stable".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1887,11 +1824,6 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let mut storage_options = HashMap::new(); storage_options.insert( OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS.to_string(), @@ -1914,7 +1846,7 @@ mod tests { .create_table(CreateTableRequest { name: "test_stable_table_level".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options, location: None, @@ -1963,11 +1895,6 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let mut storage_options = HashMap::new(); 
storage_options.insert( OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS.to_string(), @@ -1990,7 +1917,7 @@ mod tests { .create_table(CreateTableRequest { name: "test_override".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options, location: None, @@ -2108,7 +2035,7 @@ mod tests { db.create_table(CreateTableRequest { name: "table1".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema.clone())), + data: Box::new(RecordBatch::new_empty(schema.clone())) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -2120,7 +2047,7 @@ mod tests { db.create_table(CreateTableRequest { name: "table2".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index b8ad19cd0..91a55809e 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -354,15 +354,13 @@ mod tests { use super::*; use crate::connect_namespace; use crate::query::ExecutableQuery; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use tempfile::tempdir; /// Helper function to create test data - fn create_test_data() -> RecordBatchIterator< - std::vec::IntoIter>, - > { + fn create_test_data() -> RecordBatch { let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), @@ -371,12 +369,7 @@ mod tests { let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); let name_array = StringArray::from(vec!["Alice", "Bob", 
"Charlie", "David", "Eve"]); - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(id_array), Arc::new(name_array)], - ) - .unwrap(); - RecordBatchIterator::new(vec![std::result::Result::Ok(batch)].into_iter(), schema) + RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(name_array)]).unwrap() } #[tokio::test] @@ -618,13 +611,7 @@ mod tests { // Test: Overwrite the table let table2 = conn - .create_table( - "overwrite_test", - RecordBatchIterator::new( - vec![std::result::Result::Ok(test_data2)].into_iter(), - schema, - ), - ) + .create_table("overwrite_test", test_data2) .namespace(vec!["test_ns".into()]) .mode(CreateTableMode::Overwrite) .execute() diff --git a/rust/lancedb/src/dataloader/permutation/builder.rs b/rust/lancedb/src/dataloader/permutation/builder.rs index 9fb507443..c0c418e55 100644 --- a/rust/lancedb/src/dataloader/permutation/builder.rs +++ b/rust/lancedb/src/dataloader/permutation/builder.rs @@ -13,7 +13,7 @@ use lance_datafusion::exec::SessionContextExt; use crate::{ arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream}, connect, - database::{CreateTableData, CreateTableRequest, Database}, + database::{CreateTableRequest, Database}, dataloader::permutation::{ shuffle::{Shuffler, ShufflerConfig}, split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN}, @@ -313,10 +313,8 @@ impl PermutationBuilder { } }; - let create_table_request = CreateTableRequest::new( - name.to_string(), - CreateTableData::StreamingData(streaming_data), - ); + let create_table_request = + CreateTableRequest::new(name.to_string(), Box::new(streaming_data)); let table = database.create_table(create_table_request).await?; @@ -347,7 +345,7 @@ mod tests { .col("col_b", lance_datagen::array::step::()) .into_ldb_stream(RowCount::from(100), BatchCount::from(10)); let data_table = db - .create_table_streaming("base_tbl", initial_data) + .create_table("base_tbl", initial_data) .execute() .await .unwrap(); @@ -387,7 +385,7 @@ 
mod tests { .col("some_value", lance_datagen::array::step::()) .into_ldb_stream(RowCount::from(100), BatchCount::from(10)); let data_table = db - .create_table_streaming("mytbl", initial_data) + .create_table("mytbl", initial_data) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs index 56b87c1cb..b6272a123 100644 --- a/rust/lancedb/src/embeddings.rs +++ b/rust/lancedb/src/embeddings.rs @@ -18,7 +18,7 @@ use std::{ }; use arrow_array::{Array, RecordBatch, RecordBatchReader}; -use arrow_schema::{DataType, Field, SchemaBuilder}; +use arrow_schema::{DataType, Field, SchemaBuilder, SchemaRef}; // use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -190,6 +190,112 @@ impl WithEmbeddings { } } +/// Compute embedding arrays for a batch. +/// +/// When multiple embedding functions are defined, they are computed in parallel using +/// scoped threads. For a single embedding function, computation is done inline. +fn compute_embedding_arrays( + batch: &RecordBatch, + embeddings: &[(EmbeddingDefinition, Arc)], +) -> Result>> { + if embeddings.len() == 1 { + let (fld, func) = &embeddings[0]; + let src_column = + batch + .column_by_name(&fld.source_column) + .ok_or_else(|| Error::InvalidInput { + message: format!("Source column '{}' not found", fld.source_column), + })?; + return Ok(vec![func.compute_source_embeddings(src_column.clone())?]); + } + + // Parallel path: multiple embeddings + std::thread::scope(|s| { + let handles: Vec<_> = embeddings + .iter() + .map(|(fld, func)| { + let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| { + Error::InvalidInput { + message: format!("Source column '{}' not found", fld.source_column), + } + })?; + + let handle = s.spawn(move || func.compute_source_embeddings(src_column.clone())); + + Ok(handle) + }) + .collect::>()?; + + handles + .into_iter() + .map(|h| { + h.join().map_err(|e| Error::Runtime { + message: format!("Thread panicked during 
embedding computation: {:?}", e), + })? + }) + .collect() + }) +} + +/// Compute the output schema when embeddings are applied to a base schema. +/// +/// This returns the schema with embedding columns appended. +pub fn compute_output_schema( + base_schema: &SchemaRef, + embeddings: &[(EmbeddingDefinition, Arc)], +) -> Result { + let mut sb: SchemaBuilder = base_schema.as_ref().into(); + + for (ed, func) in embeddings { + let src_field = base_schema + .field_with_name(&ed.source_column) + .map_err(|_| Error::InvalidInput { + message: format!("Source column '{}' not found in schema", ed.source_column), + })?; + + let field_name = ed + .dest_column + .clone() + .unwrap_or_else(|| format!("{}_embedding", &ed.source_column)); + + sb.push(Field::new( + field_name, + func.dest_type()?.into_owned(), + src_field.is_nullable(), + )); + } + + Ok(Arc::new(sb.finish())) +} + +/// Compute embeddings for a batch and append as new columns. +/// +/// This function computes embeddings using the provided embedding functions and +/// appends them as new columns to the batch. 
+pub fn compute_embeddings_for_batch( + batch: RecordBatch, + embeddings: &[(EmbeddingDefinition, Arc)], +) -> Result { + let embedding_arrays = compute_embedding_arrays(&batch, embeddings)?; + + let mut result = batch; + for ((fld, _), embedding) in embeddings.iter().zip(embedding_arrays.iter()) { + let dst_field_name = fld + .dest_column + .clone() + .unwrap_or_else(|| format!("{}_embedding", &fld.source_column)); + + let dst_field = Field::new( + dst_field_name, + embedding.data_type().clone(), + embedding.nulls().is_some(), + ); + + result = result.try_with_column(dst_field, embedding.clone())?; + } + Ok(result) +} + impl WithEmbeddings { fn dest_fields(&self) -> Result> { let schema = self.inner.schema(); @@ -240,48 +346,6 @@ impl WithEmbeddings { column_definitions, }) } - - fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result>> { - if self.embeddings.len() == 1 { - let (fld, func) = &self.embeddings[0]; - let src_column = - batch - .column_by_name(&fld.source_column) - .ok_or_else(|| Error::InvalidInput { - message: format!("Source column '{}' not found", fld.source_column), - })?; - return Ok(vec![func.compute_source_embeddings(src_column.clone())?]); - } - - // Parallel path: multiple embeddings - std::thread::scope(|s| { - let handles: Vec<_> = self - .embeddings - .iter() - .map(|(fld, func)| { - let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| { - Error::InvalidInput { - message: format!("Source column '{}' not found", fld.source_column), - } - })?; - - let handle = - s.spawn(move || func.compute_source_embeddings(src_column.clone())); - - Ok(handle) - }) - .collect::>()?; - - handles - .into_iter() - .map(|h| { - h.join().map_err(|e| Error::Runtime { - message: format!("Thread panicked during embedding computation: {:?}", e), - })? 
- }) - .collect() - }) - } } impl Iterator for MaybeEmbedded { @@ -309,37 +373,13 @@ impl Iterator for WithEmbeddings { fn next(&mut self) -> Option { let batch = self.inner.next()?; match batch { - Ok(batch) => { - let embeddings = match self.compute_embeddings_parallel(&batch) { - Ok(emb) => emb, - Err(e) => { - return Some(Err(arrow_schema::ArrowError::ComputeError(format!( - "Error computing embedding: {}", - e - )))) - } - }; - - let mut batch = batch; - for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) { - let dst_field_name = fld - .dest_column - .clone() - .unwrap_or_else(|| format!("{}_embedding", &fld.source_column)); - - let dst_field = Field::new( - dst_field_name, - embedding.data_type().clone(), - embedding.nulls().is_some(), - ); - - match batch.try_with_column(dst_field.clone(), embedding.clone()) { - Ok(b) => batch = b, - Err(e) => return Some(Err(e)), - }; - } - Some(Ok(batch)) - } + Ok(batch) => match compute_embeddings_for_batch(batch, &self.embeddings) { + Ok(batch_with_embeddings) => Some(Ok(batch_with_embeddings)), + Err(e) => Some(Err(arrow_schema::ArrowError::ComputeError(format!( + "Error computing embedding: {}", + e + )))), + }, Err(e) => Some(Err(e)), } } diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 4312b3860..55e2350ac 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -6,7 +6,7 @@ use std::sync::PoisonError; use arrow_schema::ArrowError; use snafu::Snafu; -type BoxError = Box; +pub(crate) type BoxError = Box; #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] @@ -80,6 +80,9 @@ pub enum Error { Arrow { source: ArrowError }, #[snafu(display("LanceDBError: not supported: {message}"))] NotSupported { message: String }, + /// External error pass through from user code. 
+ #[snafu(transparent)] + External { source: BoxError }, #[snafu(whatever, display("{message}"))] Other { message: String, @@ -92,15 +95,26 @@ pub type Result = std::result::Result; impl From for Error { fn from(source: ArrowError) -> Self { - Self::Arrow { source } + match source { + ArrowError::ExternalError(source) => match source.downcast::() { + Ok(e) => *e, + Err(source) => Self::External { source }, + }, + _ => Self::Arrow { source }, + } } } impl From for Error { fn from(source: lance::Error) -> Self { - // TODO: Once Lance is changed to preserve ObjectStore, DataFusion, and Arrow errors, we can - // pass those variants through here as well. - Self::Lance { source } + // Try to unwrap external errors that were wrapped by lance + match source { + lance::Error::Wrapped { error, .. } => match error.downcast::() { + Ok(e) => *e, + Err(source) => Self::External { source }, + }, + _ => Self::Lance { source }, + } } } diff --git a/rust/lancedb/src/io/object_store.rs b/rust/lancedb/src/io/object_store.rs index 4988ae05c..d935c099a 100644 --- a/rust/lancedb/src/io/object_store.rs +++ b/rust/lancedb/src/io/object_store.rs @@ -218,8 +218,9 @@ mod test { datagen = datagen.col(Box::::default()); datagen = datagen.col(Box::new(RandomVector::default().named("vector".into()))); + let data: Box = Box::new(datagen.batch(100)); let res = db - .create_table("test", Box::new(datagen.batch(100))) + .create_table("test", data) .write_options(WriteOptions { lance_write_params: Some(param), }) diff --git a/rust/lancedb/src/ipc.rs b/rust/lancedb/src/ipc.rs index c814b41bf..e71f2e521 100644 --- a/rust/lancedb/src/ipc.rs +++ b/rust/lancedb/src/ipc.rs @@ -12,10 +12,10 @@ use arrow_schema::Schema; use crate::{Error, Result}; /// Convert a Arrow IPC file to a batch reader -pub fn ipc_file_to_batches(buf: Vec) -> Result { +pub fn ipc_file_to_batches(buf: Vec) -> Result> { let buf_reader = Cursor::new(buf); let reader = FileReader::try_new(buf_reader, None)?; - Ok(reader) + 
Ok(Box::new(reader)) } /// Convert record batches to Arrow IPC file diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 64699cb15..944613253 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -39,7 +39,6 @@ //! #### Connect to a database. //! //! ```rust -//! # use arrow_schema::{Field, Schema}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! let db = lancedb::connect("data/sample-lancedb").execute().await.unwrap(); //! # }); @@ -74,7 +73,10 @@ //! //! #### Create a table //! -//! To create a Table, you need to provide a [`arrow_schema::Schema`] and a [`arrow_array::RecordBatch`] stream. +//! To create a Table, you need to provide an [`arrow_array::RecordBatch`]. The +//! schema of the `RecordBatch` determines the schema of the table. +//! +//! Vector columns should be represented as `FixedSizeList` data type. //! //! ```rust //! # use std::sync::Arc; @@ -85,34 +87,29 @@ //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); //! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); +//! let ndims = 128; //! let schema = Arc::new(Schema::new(vec![ //! Field::new("id", DataType::Int32, false), //! Field::new( //! "vector", -//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128), +//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), ndims), //! true, //! ), //! ])); -//! // Create a RecordBatch stream. -//! let batches = RecordBatchIterator::new( -//! vec![RecordBatch::try_new( +//! let data = RecordBatch::try_new( //! schema.clone(), //! vec![ //! Arc::new(Int32Array::from_iter_values(0..256)), //! Arc::new( //! FixedSizeListArray::from_iter_primitive::( -//! (0..256).map(|_| Some(vec![Some(1.0); 128])), -//! 128, +//! (0..256).map(|_| Some(vec![Some(1.0); ndims as usize])), +//! ndims, //! ), //! ), //! ], //! ) -//! .unwrap()] -//! .into_iter() -//! 
.map(Ok), -//! schema.clone(), -//! ); -//! db.create_table("my_table", Box::new(batches)) +//! .unwrap(); +//! db.create_table("my_table", data) //! .execute() //! .await //! .unwrap(); @@ -151,42 +148,18 @@ //! #### Open table and search //! //! ```rust -//! # use std::sync::Arc; //! # use futures::TryStreamExt; -//! # use arrow_schema::{DataType, Schema, Field}; -//! # use arrow_array::{RecordBatch, RecordBatchIterator}; -//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; //! # use lancedb::query::{ExecutableQuery, QueryBase}; -//! # tokio::runtime::Runtime::new().unwrap().block_on(async { -//! # let tmpdir = tempfile::tempdir().unwrap(); -//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); -//! # let schema = Arc::new(Schema::new(vec![ -//! # Field::new("id", DataType::Int32, false), -//! # Field::new("vector", DataType::FixedSizeList( -//! # Arc::new(Field::new("item", DataType::Float32, true)), 128), true), -//! # ])); -//! # let batches = RecordBatchIterator::new(vec![ -//! # RecordBatch::try_new(schema.clone(), -//! # vec![ -//! # Arc::new(Int32Array::from_iter_values(0..10)), -//! # Arc::new(FixedSizeListArray::from_iter_primitive::( -//! # (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)), -//! # ]).unwrap() -//! # ].into_iter().map(Ok), -//! # schema.clone()); -//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap(); -//! # let table = db.open_table("my_table").execute().await.unwrap(); +//! # async fn example(table: &lancedb::Table) -> lancedb::Result<()> { //! let results = table //! .query() -//! .nearest_to(&[1.0; 128]) -//! .unwrap() +//! .nearest_to(&[1.0; 128])? //! .execute() -//! .await -//! .unwrap() +//! .await? //! .try_collect::>() -//! .await -//! .unwrap(); -//! # }); +//! .await?; +//! # Ok(()) +//! # } //! 
``` pub mod arrow; diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index c3e2f8faa..b1691c8ac 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -1381,7 +1381,7 @@ mod tests { use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array, - RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + RecordBatch, StringArray, }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::{StreamExt, TryStreamExt}; @@ -1402,7 +1402,7 @@ mod tests { let batches = make_test_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1463,7 +1463,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1525,7 +1525,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1578,7 +1578,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1599,13 +1599,13 @@ mod tests { assert!(result.is_err()); } - fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { + fn make_non_empty_batches() -> Box { let vec = Box::new(RandomVector::new().named("vector".to_string())); let id = Box::new(IncrementingInt32::new().named("id".to_string())); - BatchGenerator::new().col(vec).col(id).batch(512) + 
Box::new(BatchGenerator::new().col(vec).col(id).batch(512)) } - fn make_test_batches() -> impl RecordBatchReader + Send + 'static { + fn make_test_batches() -> RecordBatch { let dim: usize = 128; let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("key", DataType::Int32, false), @@ -1619,12 +1619,7 @@ mod tests { ), ArrowField::new("uri", DataType::Utf8, true), ])); - RecordBatchIterator::new( - vec![RecordBatch::new_empty(schema.clone())] - .into_iter() - .map(Ok), - schema, - ) + RecordBatch::new_empty(schema) } async fn make_test_table(tmp_dir: &tempfile::TempDir) -> Table { @@ -1633,7 +1628,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); - conn.create_table("my_table", Box::new(batches)) + conn.create_table("my_table", batches) .execute() .await .unwrap() @@ -1862,10 +1857,8 @@ mod tests { let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap(); - let record_batch_iter = - RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone()); let table = conn - .create_table("my_table", record_batch_iter) + .create_table("my_table", record_batch) .execute() .await .unwrap(); @@ -1949,10 +1942,8 @@ mod tests { ], ) .unwrap(); - let record_batch_iter = - RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone()); let table = conn - .create_table("my_table", record_batch_iter) + .create_table("my_table", record_batch) .mode(CreateTableMode::Overwrite) .execute() .await @@ -2062,8 +2053,6 @@ mod tests { async fn test_pagination_with_fts() { let db = connect("memory://test").execute().await.unwrap(); let data = fts_test_data(400); - let schema = data.schema(); - let data = RecordBatchIterator::new(vec![Ok(data)], schema); let table = db.create_table("test_table", data).execute().await.unwrap(); table diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 46e51569a..73d2f14da 100644 --- 
a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -491,7 +491,7 @@ impl RestfulLanceDbClient { } /// Apply dynamic headers from the header provider if configured - async fn apply_dynamic_headers(&self, mut request: Request) -> Result { + pub(crate) async fn apply_dynamic_headers(&self, mut request: Request) -> Result { if let Some(ref provider) = self.header_provider { let headers = provider.get_headers().await?; let request_headers = request.headers_mut(); @@ -555,7 +555,9 @@ impl RestfulLanceDbClient { message: "Attempted to retry a request that cannot be cloned".to_string(), })?; let (_, r) = tmp_req.build_split(); - let mut r = r.unwrap(); + let mut r = r.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; let request_id = self.extract_request_id(&mut r); let mut retry_counter = RetryCounter::new(retry_config, request_id.clone()); @@ -571,7 +573,9 @@ impl RestfulLanceDbClient { } let (c, request) = req_builder.build_split(); - let mut request = request.unwrap(); + let mut request = request.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; self.set_request_id(&mut request, &request_id.clone()); // Apply dynamic headers before each retry attempt @@ -621,7 +625,7 @@ impl RestfulLanceDbClient { } } - fn log_request(&self, request: &Request, request_id: &String) { + pub(crate) fn log_request(&self, request: &Request, request_id: &String) { if log::log_enabled!(log::Level::Debug) { let content_type = request .headers() diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 66736a872..38541f079 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -4,13 +4,11 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::RecordBatchIterator; use async_trait::async_trait; use http::StatusCode; use lance_io::object_store::StorageOptions; use moka::future::Cache; use reqwest::header::CONTENT_TYPE; -use 
tokio::task::spawn_blocking; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, @@ -19,16 +17,17 @@ use lance_namespace::models::{ }; use crate::database::{ - CloneTableRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database, - DatabaseOptions, OpenTableRequest, ReadConsistency, TableNamesRequest, + CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, + OpenTableRequest, ReadConsistency, TableNamesRequest, }; use crate::error::Result; +use crate::remote::util::stream_as_body; use crate::table::BaseTable; use crate::Error; use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; -use super::util::{batches_to_ipc_bytes, parse_server_version}; +use super::util::parse_server_version; use super::ARROW_STREAM_CONTENT_TYPE; // Request structure for the remote clone table API @@ -436,26 +435,8 @@ impl Database for RemoteDatabase { Ok(response) } - async fn create_table(&self, request: CreateTableRequest) -> Result> { - let data = match request.data { - CreateTableData::Data(data) => data, - CreateTableData::StreamingData(_) => { - return Err(Error::NotSupported { - message: "Creating a remote table from a streaming source".to_string(), - }) - } - CreateTableData::Empty(table_definition) => { - let schema = table_definition.schema.clone(); - Box::new(RecordBatchIterator::new(vec![], schema)) - } - }; - - // TODO: https://github.com/lancedb/lancedb/issues/1026 - // We should accept data from an async source. In the meantime, spawn this as blocking - // to make sure we don't block the tokio runtime if the source is slow. 
- let data_buffer = spawn_blocking(move || batches_to_ipc_bytes(data)) - .await - .unwrap()?; + async fn create_table(&self, mut request: CreateTableRequest) -> Result> { + let body = stream_as_body(request.data.scan_as_stream())?; let identifier = build_table_identifier(&request.name, &request.namespace, &self.client.id_delimiter); @@ -463,7 +444,7 @@ impl Database for RemoteDatabase { .client .post(&format!("/v1/table/{}/create/", identifier)) .query(&[("mode", Into::<&str>::into(&request.mode))]) - .body(data_buffer) + .body(body) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); let (request_id, rsp) = self.client.send(req).await?; @@ -813,7 +794,7 @@ mod tests { use std::collections::HashMap; use std::sync::{Arc, OnceLock}; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use crate::connection::ConnectBuilder; @@ -993,8 +974,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let table = conn.create_table("table1", reader).execute().await.unwrap(); + let table = conn.create_table("table1", data).execute().await.unwrap(); assert_eq!(table.name(), "table1"); } @@ -1011,8 +991,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let result = conn.create_table("table1", reader).execute().await; + let result = conn.create_table("table1", data).execute().await; assert!(result.is_err()); assert!( matches!(result, Err(crate::Error::TableAlreadyExists { name }) if name == "table1") @@ -1045,8 +1024,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let mut builder = conn.create_table("table1", reader); + let mut builder = conn.create_table("table1", 
data.clone()); if let Some(mode) = mode { builder = builder.mode(mode); } @@ -1071,9 +1049,8 @@ mod tests { .unwrap(); let called: Arc> = Arc::new(OnceLock::new()); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); let called_in_cb = called.clone(); - conn.create_table("table1", reader) + conn.create_table("table1", data) .mode(CreateTableMode::ExistOk(Box::new(move |b| { called_in_cb.clone().set(true).unwrap(); b @@ -1262,9 +1239,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); let table = conn - .create_table("table1", reader) + .create_table("table1", data) .namespace(vec!["ns1".to_string()]) .execute() .await @@ -1730,10 +1706,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); - let table = conn - .create_table("test_table", reader) + .create_table("test_table", data) .namespace(namespace.clone()) .execute() .await; @@ -1806,9 +1780,7 @@ mod tests { let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![i]))]) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); - - conn.create_table(format!("table{}", i), reader) + conn.create_table(format!("table{}", i), data) .namespace(namespace.clone()) .execute() .await diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index cf0989b0b..8ed4b6236 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -3,9 +3,11 @@ pub mod insert; +use crate::data::scannable::Scannable; use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; +use crate::remote::util::stream_as_ipc; use crate::table::AddColumnsResult; use crate::table::AddResult; use crate::table::AlterColumnsResult; @@ -45,10 +47,10 @@ 
use tokio::sync::RwLock; use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; +use super::util::stream_as_body; use super::ARROW_STREAM_CONTENT_TYPE; use crate::index::waiter::wait_for_index; use crate::{ - connection::NoData, error::Result, index::{IndexBuilder, IndexConfig}, query::QueryExecutionOptions, @@ -264,7 +266,10 @@ impl RemoteTable { fn reader_as_body(data: Box) -> Result { // TODO: Once Phalanx supports compression, we should use it here. - let mut writer = arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &data.schema())?; + let mut writer = arrow_ipc::writer::StreamWriter::try_new( + Vec::new(), + &RecordBatchReader::schema(&*data), + )?; // Mutex is just here to make it sync. We shouldn't have any contention. let mut data = Mutex::new(data); @@ -340,6 +345,110 @@ impl RemoteTable { Ok(res) } + /// Send a request with data from a Scannable source. + /// + /// For rescannable sources, this will retry on retryable errors by re-reading + /// the data. For non-rescannable sources (streams), only a single attempt is made. + async fn send_scannable( + &self, + req_builder: RequestBuilder, + data: &mut dyn Scannable, + ) -> Result<(String, Response)> { + use crate::remote::retry::RetryCounter; + + // Right now, Python and Typescript don't pass down re-scannable data yet. + // So to preserve existing retry behavior, we have to collect data in + // memory for now. Once they expose rescannable data sources, we can remove this. + if !data.rescannable() && self.client.retry_config.retries > 0 { + let mut body = Vec::new(); + stream_as_ipc(data.scan_as_stream())? 
+ .try_for_each(|b| { + body.extend_from_slice(&b); + futures::future::ok(()) + }) + .await?; + let req_builder = req_builder.body(body); + return self.client.send_with_retry(req_builder, None, true).await; + } + + let rescannable = data.rescannable(); + let max_retries = if rescannable { + self.client.retry_config.retries + } else { + 0 + }; + + // Clone the request builder to extract the request id + let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime { + message: "Attempted to retry a request that cannot be cloned".to_string(), + })?; + let (_, r) = tmp_req.build_split(); + let mut r = r.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; + let request_id = self.client.extract_request_id(&mut r); + let mut retry_counter = RetryCounter::new(&self.client.retry_config, request_id.clone()); + + loop { + // Re-read data on each attempt + let stream = data.scan_as_stream(); + let body = stream_as_body(stream)?; + + let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime { + message: "Attempted to retry a request that cannot be cloned".to_string(), + })?; + req_builder = req_builder.body(body); + + let (c, request) = req_builder.build_split(); + let mut request = request.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; + self.client.set_request_id(&mut request, &request_id); + + // Apply dynamic headers + request = self.client.apply_dynamic_headers(request).await?; + + self.client.log_request(&request, &request_id); + + let response = match self.client.sender.send(&c, request).await { + Ok(r) => r, + Err(err) => { + if err.is_connect() { + retry_counter.increment_connect_failures(err)?; + } else if err.is_body() || err.is_decode() { + retry_counter.increment_read_failures(err)?; + } else { + return Err(crate::Error::Http { + source: err.into(), + request_id, + status_code: None, + }); + } + tokio::time::sleep(retry_counter.next_sleep_time()).await; + 
continue; + } + }; + + let status = response.status(); + + // Check for retryable status codes + if self.client.retry_config.statuses.contains(&status) + && retry_counter.request_failures < max_retries + { + let http_err = crate::Error::Http { + source: format!("Retryable status code: {}", status).into(), + request_id: request_id.clone(), + status_code: Some(status), + }; + retry_counter.increment_request_failures(http_err)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + + return Ok((request_id, response)); + } + } + pub(super) async fn handle_table_not_found( table_name: &str, response: reqwest::Response, @@ -656,8 +765,9 @@ impl std::fmt::Display for RemoteTable { #[cfg(all(test, feature = "remote"))] mod test_utils { use super::*; - use crate::remote::client::test_utils::client_with_handler; use crate::remote::client::test_utils::MockSender; + use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config}; + use crate::remote::ClientConfig; impl RemoteTable { pub fn new_mock(name: String, handler: F, version: Option) -> Self @@ -676,6 +786,23 @@ mod test_utils { location: RwLock::new(None), } } + + pub fn new_mock_with_config(name: String, handler: F, config: ClientConfig) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + let client = client_with_handler_and_config(handler, config); + Self { + client, + name: name.clone(), + namespace: vec![], + identifier: name, + server_version: ServerVersion::default(), + version: RwLock::new(None), + location: RwLock::new(None), + } + } } } @@ -797,11 +924,7 @@ impl BaseTable for RemoteTable { status_code: None, }) } - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result { + async fn add(&self, mut add: AddDataBuilder) -> Result { self.check_mutable().await?; let mut request = self .client @@ -815,7 +938,7 @@ impl BaseTable for RemoteTable { } } - let (request_id, response) = 
self.send_streaming(request, data, true).await?; + let (request_id, response) = self.send_scannable(request, &mut *add.data).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; if body.trim().is_empty() { @@ -1584,12 +1707,14 @@ impl TryFrom for MergeInsertRequest { #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use std::{collections::HashMap, pin::Pin}; use super::*; use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type}; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; use futures::{future::BoxFuture, StreamExt, TryFutureExt}; @@ -1623,7 +1748,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let example_data = || { + let example_data_for_add = || batch.clone(); + let example_data_for_merge = || -> Box { Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), @@ -1636,11 +1762,11 @@ mod tests { Box::pin(table.schema().map_ok(|_| ())), Box::pin(table.count_rows(None).map_ok(|_| ())), Box::pin(table.update().column("a", "a + 1").execute().map_ok(|_| ())), - Box::pin(table.add(example_data()).execute().map_ok(|_| ())), + Box::pin(table.add(example_data_for_add()).execute().map_ok(|_| ())), Box::pin( table .merge_insert(&["test"]) - .execute(example_data()) + .execute(example_data_for_merge()) .map_ok(|_| ()), ), Box::pin(table.delete("false").map_ok(|_| ())), @@ -1772,8 +1898,15 @@ mod tests { fn write_ipc_stream(data: &RecordBatch) -> Vec { let mut body = Vec::new(); { - let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut body, &data.schema()) - .expect("Failed to create writer"); + let options = arrow_ipc::writer::IpcWriteOptions::default() + 
.try_with_compression(Some(arrow_ipc::CompressionType::LZ4_FRAME)) + .expect("Failed to create IPC write options"); + let mut writer = arrow_ipc::writer::StreamWriter::try_new_with_options( + &mut body, + &data.schema(), + options, + ) + .expect("Failed to create writer"); writer.write(data).expect("Failed to write data"); writer.finish().expect("Failed to finish"); } @@ -1831,11 +1964,7 @@ mod tests { panic!("Unexpected request path: {}", request.url().path()); } }); - let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await - .unwrap(); + let result = table.add(data.clone()).execute().await.unwrap(); // Check version matches expected value assert_eq!(result.version, expected_version); @@ -1892,7 +2021,7 @@ mod tests { }); let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .add(data.clone()) .mode(AddDataMode::Overwrite) .execute() .await @@ -2032,7 +2161,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), )); @@ -2084,7 +2213,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), )); @@ -3015,7 +3144,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), )); @@ -3030,10 +3159,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let res = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await; + let res = table.add(data.clone()).execute().await; assert!(matches!(res, Err(Error::NotSupported { .. 
}))); let res = table @@ -3295,11 +3421,7 @@ mod tests { } }); - let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await - .unwrap(); + let result = table.add(data.clone()).execute().await.unwrap(); assert_eq!(result.version, 2); @@ -3446,4 +3568,95 @@ mod tests { assert_eq!(uri2, "gs://bucket/table"); assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call } + + #[tokio::test] + async fn test_add_retries_rescannable_data() { + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + // Configure with retries enabled (default is 3) + let config = crate::remote::ClientConfig::default(); + + let table = Table::new_with_handler_and_config( + "my_table", + move |_request| { + let count = call_count_clone.fetch_add(1, Ordering::SeqCst); + if count < 2 { + // First two attempts fail with a retryable error (409) + http::Response::builder().status(409).body("").unwrap() + } else { + // Third attempt succeeds + http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#) + .unwrap() + } + }, + config, + ); + + // RecordBatch is rescannable - should retry and succeed + let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + let result = table.add(batch).execute().await; + + assert!( + result.is_ok(), + "Expected success after retries: {:?}", + result + ); + assert_eq!( + call_count.load(Ordering::SeqCst), + 3, + "Expected 2 failed attempts + 1 success = 3 total" + ); + } + + #[tokio::test] + async fn test_add_no_retry_for_non_rescannable() { + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + // Configure with retries enabled + let config = crate::remote::ClientConfig::default(); + + let table = Table::new_with_handler_and_config( + "my_table", + move |_request| { + call_count_clone.fetch_add(1, Ordering::SeqCst); + // Always fail with retryable error + 
http::Response::builder().status(409).body("").unwrap() + }, + config, + ); + + // RecordBatchReader is NOT rescannable - should NOT retry + let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + let reader: Box = Box::new( + RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), + ); + + let result = table.add(reader).execute().await; + + // Should fail because we can't retry non-rescannable sources + assert!(result.is_err()); + // Right now, we actually do retry, so we get 3 failures. In the future + // this will change and we need to update the test. + assert!( + matches!( + result.unwrap_err(), + Error::Retry { + request_failures: 3, + .. + } + ), + "Expected RequestFailed with status 409" + ); + // TODO: After we implement proper non-rescannable handling, uncomment below + // (This is blocked on getting Python and Node to pass down re-scannable data.) + // assert_eq!( + // call_count.load(Ordering::SeqCst), + // 1, + // "Expected only one attempt for non-rescannable source" + // ); + } } diff --git a/rust/lancedb/src/remote/util.rs b/rust/lancedb/src/remote/util.rs index 4b92ec0c2..51ddb97de 100644 --- a/rust/lancedb/src/remote/util.rs +++ b/rust/lancedb/src/remote/util.rs @@ -1,29 +1,50 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::io::Cursor; - -use arrow_array::RecordBatchReader; +use arrow_ipc::CompressionType; +use futures::{Stream, StreamExt}; use reqwest::Response; -use crate::Result; +use crate::{arrow::SendableRecordBatchStream, Result}; use super::db::ServerVersion; -pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result> { +pub fn stream_as_ipc( + data: SendableRecordBatchStream, +) -> Result>> { + let options = arrow_ipc::writer::IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; const WRITE_BUF_SIZE: usize = 4096; let buf = Vec::with_capacity(WRITE_BUF_SIZE); - let mut buf = Cursor::new(buf); - { - let mut writer = 
arrow_ipc::writer::StreamWriter::try_new(&mut buf, &batches.schema())?; + let writer = + arrow_ipc::writer::StreamWriter::try_new_with_options(buf, &data.schema(), options)?; + let stream = futures::stream::try_unfold( + (data, writer, false), + move |(mut data, mut writer, finished)| async move { + if finished { + return Ok(None); + } + match data.next().await { + Some(Ok(batch)) => { + writer.write(&batch)?; + let buffer = std::mem::take(writer.get_mut()); + Ok(Some((bytes::Bytes::from(buffer), (data, writer, false)))) + } + Some(Err(e)) => Err(e), + None => { + writer.finish()?; + let buffer = std::mem::take(writer.get_mut()); + Ok(Some((bytes::Bytes::from(buffer), (data, writer, true)))) + } + } + }, + ); + Ok(stream) +} - for batch in batches { - let batch = batch?; - writer.write(&batch)?; - } - writer.finish()?; - } - Ok(buf.into_inner()) +pub fn stream_as_body(data: SendableRecordBatchStream) -> Result { + let stream = stream_as_ipc(data)?; + Ok(reqwest::Body::wrap_stream(stream)) } pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 776479a8b..b7761f5f9 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -5,7 +5,7 @@ use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; use arrow::datatypes::{Float32Type, UInt8Type}; -use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_expr::Expr; @@ -50,10 +50,9 @@ use std::format; use std::path::Path; use std::sync::Arc; -use crate::arrow::IntoArrow; -use crate::connection::NoData; +use crate::data::scannable::{scannable_with_embeddings, Scannable}; use crate::database::Database; -use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; +use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, 
MemoryRegistry}; use crate::error::{Error, Result}; use crate::index::vector::VectorIndex; use crate::index::IndexStatistics; @@ -72,6 +71,7 @@ use crate::utils::{ use self::dataset::DatasetConsistencyWrapper; use self::merge::MergeInsertBuilder; +mod add_data; pub mod datafusion; pub(crate) mod dataset; pub mod delete; @@ -80,6 +80,8 @@ pub mod optimize; pub mod schema_evolution; pub mod update; +pub use add_data::{AddDataBuilder, AddDataMode, AddResult}; + use crate::index::waiter::wait_for_index; pub use chrono::Duration; pub use delete::DeleteResult; @@ -196,60 +198,6 @@ pub struct WriteOptions { pub lance_write_params: Option, } -#[derive(Debug, Clone, Default)] -pub enum AddDataMode { - /// Rows will be appended to the table (the default) - #[default] - Append, - /// The existing table will be overwritten with the new data - Overwrite, -} - -/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`] -/// operation -pub struct AddDataBuilder { - parent: Arc, - pub(crate) data: T, - pub(crate) mode: AddDataMode, - pub(crate) write_options: WriteOptions, - embedding_registry: Option>, -} - -impl std::fmt::Debug for AddDataBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AddDataBuilder") - .field("parent", &self.parent) - .field("mode", &self.mode) - .field("write_options", &self.write_options) - .finish() - } -} - -impl AddDataBuilder { - pub fn mode(mut self, mode: AddDataMode) -> Self { - self.mode = mode; - self - } - - pub fn write_options(mut self, options: WriteOptions) -> Self { - self.write_options = options; - self - } - - pub async fn execute(self) -> Result { - let parent = self.parent.clone(); - let data = self.data.into_arrow()?; - let without_data = AddDataBuilder:: { - data: NoData {}, - mode: self.mode, - parent: self.parent, - write_options: self.write_options, - embedding_registry: self.embedding_registry, - }; - parent.add(without_data, data).await - } -} 
- /// Filters that can be used to limit the rows returned by a query pub enum Filter { /// A SQL filter string @@ -283,15 +231,6 @@ pub trait Tags: Send + Sync { async fn update(&mut self, tag: &str, version: u64) -> Result<()>; } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct AddResult { - // The commit version associated with the operation. - // A version of `0` indicates compatibility with legacy servers that do not return - /// a commit version. - #[serde(default)] - pub version: u64, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct MergeResult { // The commit version associated with the operation. @@ -364,11 +303,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { ) -> Result; /// Add new records to the table. - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result; + async fn add(&self, add: AddDataBuilder) -> Result; /// Delete rows from the table. async fn delete(&self, predicate: &str) -> Result; /// Update rows in the table. @@ -513,6 +448,30 @@ mod test_utils { embedding_registry: Arc::new(MemoryRegistry::new()), } } + + pub fn new_with_handler_and_config( + name: impl Into, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + config: crate::remote::ClientConfig, + ) -> Self + where + T: Into, + { + let inner = Arc::new(crate::remote::table::RemoteTable::new_mock_with_config( + name.into(), + handler.clone(), + config.clone(), + )); + let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config( + handler, config, + )); + Self { + inner, + database: Some(database), + // Registry is unused. 
+ embedding_registry: Arc::new(MemoryRegistry::new()), + } + } } } @@ -613,16 +572,14 @@ impl Table { /// /// # Arguments /// - /// * `batches` data to be added to the Table + /// * `data` data to be added to the Table /// * `options` options to control how data is added - pub fn add(&self, batches: T) -> AddDataBuilder { - AddDataBuilder { - parent: self.inner.clone(), - data: batches, - mode: AddDataMode::Append, - write_options: WriteOptions::default(), - embedding_registry: Some(self.embedding_registry.clone()), - } + pub fn add(&self, data: T) -> AddDataBuilder { + AddDataBuilder::new( + self.inner.clone(), + Box::new(data), + Some(self.embedding_registry.clone()), + ) } /// Update existing records in the Table @@ -661,31 +618,26 @@ impl Table { /// .execute() /// .await /// .unwrap(); - /// # let schema = Arc::new(Schema::new(vec![ - /// # Field::new("id", DataType::Int32, false), - /// # Field::new("vector", DataType::FixedSizeList( - /// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true), - /// # ])); - /// let batches = RecordBatchIterator::new( - /// vec![RecordBatch::try_new( - /// schema.clone(), - /// vec![ - /// Arc::new(Int32Array::from_iter_values(0..10)), - /// Arc::new( - /// FixedSizeListArray::from_iter_primitive::( - /// (0..10).map(|_| Some(vec![Some(1.0); 128])), - /// 128, - /// ), - /// ), - /// ], - /// ) - /// .unwrap()] - /// .into_iter() - /// .map(Ok), + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("vector", DataType::FixedSizeList( + /// Arc::new(Field::new("item", DataType::Float32, true)), 128), true), + /// ])); + /// let data = RecordBatch::try_new( /// schema.clone(), - /// ); + /// vec![ + /// Arc::new(Int32Array::from_iter_values(0..10)), + /// Arc::new( + /// FixedSizeListArray::from_iter_primitive::( + /// (0..10).map(|_| Some(vec![Some(1.0); 128])), + /// 128, + /// ), + /// ), + /// ], + /// ) + /// .unwrap(); /// let tbl = db - /// 
.create_table("delete_test", Box::new(batches)) + /// .create_table("delete_test", data) /// .execute() /// .await /// .unwrap(); @@ -1445,7 +1397,7 @@ impl NativeTable { name: name.to_string(), source: Box::new(e), }, - source => Error::Lance { source }, + e => e.into(), })?; let dataset = DatasetConsistencyWrapper::new_latest(dataset, read_consistency_interval); @@ -1529,7 +1481,7 @@ impl NativeTable { lance::Error::Namespace { source, .. } => Error::Runtime { message: format!("Failed to get table info from namespace: {:?}", source), }, - source => Error::Lance { source }, + e => e.into(), })?; let dataset = builder @@ -1541,7 +1493,7 @@ impl NativeTable { name: name.to_string(), source: Box::new(e), }, - source => Error::Lance { source }, + e => e.into(), })?; let uri = dataset.uri().to_string(); @@ -1635,7 +1587,7 @@ impl NativeTable { lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists { name: name.to_string(), }, - source => Error::Lance { source }, + e => e.into(), })?; let id = Self::build_id(&namespace, name); @@ -1662,12 +1614,12 @@ impl NativeTable { read_consistency_interval: Option, namespace_client: Option>, ) -> Result { - let batches = RecordBatchIterator::new(vec![], schema); + let data: Box = Box::new(RecordBatch::new_empty(schema)); Self::create( uri, name, namespace, - batches, + data, write_store_wrapper, params, read_consistency_interval, @@ -1756,7 +1708,7 @@ impl NativeTable { lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists { name: name.to_string(), }, - source => Error::Lance { source }, + e => e.into(), })?; let id = Self::build_id(&namespace, name); @@ -2538,17 +2490,7 @@ impl BaseTable for NativeTable { } } - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result { - let data = Box::new(MaybeEmbedded::try_new( - data, - self.table_definition().await?, - add.embedding_registry, - )?) 
as Box; - + async fn add(&self, add: AddDataBuilder) -> Result { let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams { mode: match add.mode { AddDataMode::Append => WriteMode::Append, @@ -2557,6 +2499,11 @@ impl BaseTable for NativeTable { ..Default::default() }); + // Apply embeddings if configured + let table_def = self.table_definition().await?; + let data = + scannable_with_embeddings(add.data, &table_def, add.embedding_registry.as_ref())?; + let dataset = { // Limited scope for the mutable borrow of self.dataset avoids deadlock. let ds = self.dataset.get_mut().await?; @@ -3164,7 +3111,6 @@ mod tests { use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; - use lance::dataset::WriteMode; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lance::Dataset; use rand::Rng; @@ -3183,8 +3129,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); @@ -3217,9 +3164,12 @@ mod tests { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); - let batches = make_test_batches(); - let batches = Box::new(batches) as Box; - let table = NativeTable::create(uri, "test", vec![], batches, None, None, None, None) + let batch = make_test_batches(); + let reader: Box = Box::new(RecordBatchIterator::new( + vec![Ok(batch.clone())], + batch.schema(), + )); + let table = NativeTable::create(uri, "test", vec![], reader, None, None, None, None) .await .unwrap(); @@ -3233,33 +3183,6 @@ mod tests { ); } - #[tokio::test] - async fn test_add() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let conn = 
connect(uri).execute().await.unwrap(); - - let batches = make_test_batches(); - let schema = batches.schema().clone(); - let table = conn.create_table("test", batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - - let new_batches = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(100..110))], - ) - .unwrap()] - .into_iter() - .map(Ok), - schema.clone(), - ); - - table.add(new_batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 20); - assert_eq!(table.name(), "test"); - } - #[tokio::test] async fn test_merge_insert() { let tmp_dir = tempdir().unwrap(); @@ -3276,7 +3199,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 10); // Create new data with i=5..15 - let new_batches = Box::new(merge_insert_test_batches(5, 1)); + let new_batches = merge_insert_test_batches(5, 1); // Perform a "insert if not exists" let mut merge_insert_builder = table.merge_insert(&["i"]); @@ -3290,7 +3213,7 @@ mod tests { assert_eq!(result.num_attempts, 1); // Create new data with i=15..25 (no id matches) - let new_batches = Box::new(merge_insert_test_batches(15, 2)); + let new_batches = merge_insert_test_batches(15, 2); // Perform a "bulk update" (should not affect anything) let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_matched_update_all(None); @@ -3303,7 +3226,7 @@ mod tests { ); // Conditional update that only replaces the age=0 data - let new_batches = Box::new(merge_insert_test_batches(5, 3)); + let new_batches = merge_insert_test_batches(5, 3); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string())); merge_insert_builder.execute(new_batches).await.unwrap(); @@ -3329,7 +3252,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 10); // Test use_index=true (default behavior) - let new_batches = 
Box::new(merge_insert_test_batches(5, 1)); + let new_batches = merge_insert_test_batches(5, 1); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_not_matched_insert_all(); merge_insert_builder.use_index(true); @@ -3337,7 +3260,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 15); // Test use_index=false (force table scan) - let new_batches = Box::new(merge_insert_test_batches(15, 2)); + let new_batches = merge_insert_test_batches(15, 2); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_not_matched_insert_all(); merge_insert_builder.use_index(false); @@ -3345,59 +3268,6 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 25); } - #[tokio::test] - async fn test_add_overwrite() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let conn = connect(uri).execute().await.unwrap(); - - let batches = make_test_batches(); - let schema = batches.schema().clone(); - let table = conn.create_table("test", batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - - let batches = vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(100..110))], - ) - .unwrap()] - .into_iter() - .map(Ok); - - let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone()); - - // Can overwrite using AddDataOptions::mode - table - .add(new_batches) - .mode(AddDataMode::Overwrite) - .execute() - .await - .unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name(), "test"); - - // Can overwrite using underlying WriteParams (which - // take precedence over AddDataOptions::mode) - - let param: WriteParams = WriteParams { - mode: WriteMode::Overwrite, - ..Default::default() - }; - - let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone()); - table - .add(new_batches) - .write_options(WriteOptions { - lance_write_params: Some(param), - 
}) - .mode(AddDataMode::Append) - .execute() - .await - .unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name(), "test"); - } - #[derive(Default, Debug)] struct NoOpCacheWrapper { called: AtomicBool, @@ -3453,35 +3323,25 @@ mod tests { assert!(wrapper.called()); } - fn merge_insert_test_batches( - offset: i32, - age: i32, - ) -> impl RecordBatchReader + Send + Sync + 'static { + fn merge_insert_test_batches(offset: i32, age: i32) -> Box { let schema = Arc::new(Schema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("age", DataType::Int32, false), ])); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(offset..(offset + 10))), - Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))), - ], - )], - schema, + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(offset..(offset + 10))), + Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))), + ], ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) } - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batches() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(0..10))], - )], - schema, - ) + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from_iter_values(0..10))]).unwrap() } #[tokio::test] @@ -3569,14 +3429,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", 
batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); assert_eq!(table.index_stats("my_index").await.unwrap(), None); @@ -3658,14 +3513,9 @@ mod tests { let float_arr = Float32Array::from_iter_values((0..(num_rows * dimension)).map(|v| v as f32)); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension as i32).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let native_table = table.as_native().unwrap(); let builder = IvfPqIndexBuilder::default(); table @@ -3735,14 +3585,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let stats = table.index_stats("my_index").await.unwrap(); assert!(stats.is_none()); @@ -3800,14 +3645,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let stats = 
table.index_stats("my_index").await.unwrap(); assert!(stats.is_none()); @@ -3850,7 +3690,7 @@ mod tests { Ok(FixedSizeListArray::from(data)) } - fn some_sample_data() -> Box { + fn some_sample_data() -> Box { let batch = RecordBatch::try_new( Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])), vec![Arc::new(Int32Array::from(vec![1]))], @@ -3874,10 +3714,7 @@ mod tests { .unwrap(); let conn = ConnectBuilder::new(uri).execute().await.unwrap(); let table = conn - .create_table( - "my_table", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("my_table", batch.clone()) .execute() .await .unwrap(); @@ -3956,10 +3793,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4060,10 +3894,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4123,10 +3954,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4171,7 +3999,7 @@ mod tests { let conn1 = ConnectBuilder::new(uri).execute().await.unwrap(); let table1 = conn1 - .create_empty_table("my_table", data.schema()) + .create_empty_table("my_table", RecordBatchReader::schema(&data)) .execute() .await .unwrap(); @@ -4441,10 +4269,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_stats", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_stats", batch.clone()) .execute() .await .unwrap(); @@ -4457,21 +4282,11 @@ mod tests { ], ) .unwrap(); - table - .add(RecordBatchIterator::new( - 
vec![Ok(batch.clone())], - batch.schema(), - )) - .execute() - .await - .unwrap(); + table.add(batch.clone()).execute().await.unwrap(); } let empty_table = conn - .create_table( - "test_stats_empty", - RecordBatchIterator::new(vec![], batch.schema()), - ) + .create_table("test_stats_empty", RecordBatch::new_empty(batch.schema())) .execute() .await .unwrap(); @@ -4545,22 +4360,12 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_list_indices_skip_frag_reuse", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_list_indices_skip_frag_reuse", batch.clone()) .execute() .await .unwrap(); - table - .add(RecordBatchIterator::new( - vec![Ok(batch.clone())], - batch.schema(), - )) - .execute() - .await - .unwrap(); + table.add(batch.clone()).execute().await.unwrap(); table .create_index(&["id"], Index::Bitmap(BitmapIndexBuilder {})) @@ -4590,8 +4395,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test_ns_query.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); @@ -4643,8 +4449,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test_ns_plain.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs new file mode 100644 index 000000000..740e53011 --- /dev/null +++ b/rust/lancedb/src/table/add_data.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: 
Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::data::scannable::Scannable; +use crate::embeddings::EmbeddingRegistry; +use crate::Result; + +use super::{BaseTable, WriteOptions}; + +#[derive(Debug, Clone, Default)] +pub enum AddDataMode { + /// Rows will be appended to the table (the default) + #[default] + Append, + /// The existing table will be overwritten with the new data + Overwrite, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct AddResult { + // The commit version associated with the operation. + // A version of `0` indicates compatibility with legacy servers that do not return + /// a commit version. + #[serde(default)] + pub version: u64, +} + +/// A builder for configuring a [`crate::table::Table::add`] operation +pub struct AddDataBuilder { + pub(crate) parent: Arc, + pub(crate) data: Box, + pub(crate) mode: AddDataMode, + pub(crate) write_options: WriteOptions, + pub(crate) embedding_registry: Option>, +} + +impl std::fmt::Debug for AddDataBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AddDataBuilder") + .field("parent", &self.parent) + .field("mode", &self.mode) + .field("write_options", &self.write_options) + .finish() + } +} + +impl AddDataBuilder { + pub(crate) fn new( + parent: Arc, + data: Box, + embedding_registry: Option>, + ) -> Self { + Self { + parent, + data, + mode: AddDataMode::Append, + write_options: WriteOptions::default(), + embedding_registry, + } + } + + pub fn mode(mut self, mode: AddDataMode) -> Self { + self.mode = mode; + self + } + + pub fn write_options(mut self, options: WriteOptions) -> Self { + self.write_options = options; + self + } + + pub async fn execute(self) -> Result { + self.parent.clone().add(self).await + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{record_batch, RecordBatch, 
RecordBatchIterator}; + use arrow_schema::{ArrowError, DataType, Field, Schema}; + use futures::TryStreamExt; + use lance::dataset::{WriteMode, WriteParams}; + + use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}; + use crate::connect; + use crate::data::scannable::Scannable; + use crate::embeddings::{ + EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, + }; + use crate::query::{ExecutableQuery, QueryBase, Select}; + use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions}; + use crate::test_utils::embeddings::MockEmbed; + use crate::Error; + + use super::AddDataMode; + + async fn create_test_table() -> Table { + let conn = connect("memory://").execute().await.unwrap(); + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + conn.create_table("test", batch).execute().await.unwrap() + } + + async fn test_add_with_data(data: T) + where + T: Scannable + 'static, + { + let table = create_test_table().await; + let schema = data.schema(); + table.add(data).execute().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 5); // 3 initial + 2 added + assert_eq!(table.schema().await.unwrap(), schema); + } + + #[tokio::test] + async fn test_add_with_batch() { + let batch = record_batch!(("id", Int64, [4, 5])).unwrap(); + test_add_with_data(batch).await; + } + + #[tokio::test] + async fn test_add_with_vec_batch() { + let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + test_add_with_data(data).await; + } + + #[tokio::test] + async fn test_add_with_record_batch_reader() { + let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + let schema = data[0].schema(); + let reader: Box = Box::new( + RecordBatchIterator::new(data.into_iter().map(Ok), schema.clone()), + ); + test_add_with_data(reader).await; + } + + #[tokio::test] + async fn test_add_with_stream() { + 
let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + let schema = data[0].schema(); + let inner = futures::stream::iter(data.into_iter().map(Ok)); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema, + stream: inner, + }); + test_add_with_data(stream).await; + } + + #[derive(Debug)] + struct MyError; + + impl std::fmt::Display for MyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MyError occurred") + } + } + + impl std::error::Error for MyError {} + + #[tokio::test] + async fn test_add_preserves_reader_error() { + let table = create_test_table().await; + let first_batch = record_batch!(("id", Int64, [4])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(ArrowError::ExternalError(Box::new(MyError))), + ]; + let reader: Box = Box::new( + RecordBatchIterator::new(iterator.into_iter(), schema.clone()), + ); + + let result = table.add(reader).execute().await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_add_preserves_stream_error() { + let table = create_test_table().await; + let first_batch = record_batch!(("id", Int64, [4])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(Error::External { + source: Box::new(MyError), + }), + ]; + let stream = futures::stream::iter(iterator); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream, + }); + + let result = table.add(stream).execute().await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_add() { + let conn = connect("memory://").execute().await.unwrap(); + + let batch = record_batch!(("i", Int32, [0, 1, 2])).unwrap(); + let table = conn + .create_table("test", batch.clone()) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let new_batch = 
record_batch!(("i", Int32, [3])).unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 4); + assert_eq!(table.schema().await.unwrap(), batch.schema()); + } + + #[tokio::test] + async fn test_add_overwrite() { + let conn = connect("memory://").execute().await.unwrap(); + + let batch = record_batch!(("i", Int32, [0, 1, 2])).unwrap(); + let table = conn + .create_table("test", batch.clone()) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), batch.num_rows()); + + let new_batch = record_batch!(("x", Float32, [0.0, 1.0])).unwrap(); + let res = table + .add(new_batch.clone()) + .mode(AddDataMode::Overwrite) + .execute() + .await + .unwrap(); + assert_eq!(res.version, table.version().await.unwrap()); + assert_eq!(table.count_rows(None).await.unwrap(), new_batch.num_rows()); + assert_eq!(table.schema().await.unwrap(), new_batch.schema()); + + // Can overwrite using underlying WriteParams (which + // take precedence over AddDataMode) + let param: WriteParams = WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }; + + table + .add(new_batch.clone()) + .write_options(WriteOptions { + lance_write_params: Some(param), + }) + .mode(AddDataMode::Append) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), new_batch.num_rows()); + } + + #[tokio::test] + async fn test_add_with_embeddings() { + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])); + + // Add embedding metadata to the schema + let 
embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + let table_def = TableDefinition::new( + schema.clone(), + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + let rich_schema = table_def.into_rich_schema(); + + let table = conn + .create_empty_table("embed_test", rich_schema) + .execute() + .await + .unwrap(); + + // Now add new data WITHOUT the embedding column - it should be computed automatically + let new_batch = record_batch!(("text", Utf8, ["hello", "world"])).unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 2); + + // Query to verify the embeddings were computed for the new rows + let results: Vec = table + .query() + .select(Select::columns(&["text", "text_embedding"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + + // Check that all rows have embedding values (not null) + for batch in &results { + let embedding_col = batch.column(1); + assert_eq!(embedding_col.null_count(), 0); + } + } +} diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs index 760871631..c7ced18ea 100644 --- a/rust/lancedb/src/table/datafusion.rs +++ b/rust/lancedb/src/table/datafusion.rs @@ -287,8 +287,7 @@ pub mod tests { use arrow::array::AsArray; use arrow_array::{ - BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, RecordBatchIterator, - RecordBatchReader, StringArray, UInt32Array, + BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, UInt32Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::{ @@ -308,7 +307,7 @@ pub mod tests { table::datafusion::BaseTableAdapter, }; - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batches() -> 
RecordBatch { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let schema = Arc::new( Schema::new(vec![ @@ -317,19 +316,17 @@ pub mod tests { ]) .with_metadata(metadata), ); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(UInt32Array::from_iter_values(0..10)), - ], - )], + RecordBatch::try_new( schema, + vec![ + Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + ], ) + .unwrap() } - fn make_tbl_two_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_tbl_two_test_batches() -> RecordBatch { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let schema = Arc::new( Schema::new(vec![ @@ -342,28 +339,26 @@ pub mod tests { ]) .with_metadata(metadata), ); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int64Array::from_iter_values(0..1000)), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| i.to_string()), - )), - Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| format!("{{\"i\":{}}}", i)), - )), - Arc::new(BinaryArray::from_iter_values( - (0..1000).map(|i| (i as u32).to_be_bytes().to_vec()), - )), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| i.to_string()), - )), - ], - )], + RecordBatch::try_new( schema, + vec![ + Arc::new(Int64Array::from_iter_values(0..1000)), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| i.to_string()), + )), + Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| format!("{{\"i\":{}}}", i)), + )), + Arc::new(BinaryArray::from_iter_values( + (0..1000).map(|i| (i as u32).to_be_bytes().to_vec()), + )), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| i.to_string()), 
+ )), + ], ) + .unwrap() } struct TestFixture { diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs index e3ee371ed..53ae9520c 100644 --- a/rust/lancedb/src/table/datafusion/insert.rs +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -222,7 +222,7 @@ mod tests { use std::vec; use super::*; - use arrow_array::{record_batch, Int32Array, RecordBatchIterator}; + use arrow_array::{record_batch, RecordBatchIterator}; use datafusion::prelude::SessionContext; use datafusion_catalog::MemTable; use tempfile::tempdir; @@ -238,11 +238,8 @@ mod tests { // Create initial table let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - let table = db - .create_table("test_insert", Box::new(reader)) + .create_table("test_insert", batch) .execute() .await .unwrap(); @@ -279,11 +276,8 @@ mod tests { // Create initial table with 3 rows let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - let table = db - .create_table("test_overwrite", Box::new(reader)) + .create_table("test_overwrite", batch) .execute() .await .unwrap(); @@ -318,20 +312,9 @@ mod tests { let db = connect(uri).execute().await.unwrap(); // Create initial table - let schema = Arc::new(ArrowSchema::new(vec![Field::new( - "id", - DataType::Int32, - false, - )])); - let batches = vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - ) - .unwrap()]; - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); let table = db - .create_table("test_empty", Box::new(reader)) + .create_table("test_empty", batch) .execute() .await .unwrap(); @@ -352,12 +335,13 @@ mod tests { false, )])); // Empty batches - let source_reader = 
RecordBatchIterator::new( - std::iter::empty::>(), - source_schema, - ); + let source_reader: Box = + Box::new(RecordBatchIterator::new( + std::iter::empty::>(), + source_schema, + )); let source_table = db - .create_table("empty_source", Box::new(source_reader)) + .create_table("empty_source", source_reader) .execute() .await .unwrap(); @@ -389,20 +373,10 @@ mod tests { let db = connect(uri).execute().await.unwrap(); // Create initial table - let schema = Arc::new(ArrowSchema::new(vec![Field::new( - "id", - DataType::Int32, - true, - )])); - let batches = - vec![ - RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))]) - .unwrap(), - ]; - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - + let batch = record_batch!(("id", Int32, [1])).unwrap(); + let schema = batch.schema(); let table = db - .create_table("test_multi_batch", Box::new(reader)) + .create_table("test_multi_batch", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/datafusion/udtf/fts.rs b/rust/lancedb/src/table/datafusion/udtf/fts.rs index c3eea0574..2f8b8beb8 100644 --- a/rust/lancedb/src/table/datafusion/udtf/fts.rs +++ b/rust/lancedb/src/table/datafusion/udtf/fts.rs @@ -97,7 +97,7 @@ mod tests { table::datafusion::BaseTableAdapter, Connection, Table, }; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use datafusion::prelude::SessionContext; @@ -173,14 +173,7 @@ mod tests { // Create LanceDB database and table let db = crate::connect("memory://test").execute().await.unwrap(); - let table = db - .create_table( - "foo", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("foo", batch).execute().await.unwrap(); // Create FTS index table @@ -323,13 +316,7 @@ mod tests { 
RecordBatch::try_new(metadata_schema.clone(), vec![metadata_col, extra_col]).unwrap(); let _metadata_table = db - .create_table( - "metadata", - RecordBatchIterator::new( - vec![Ok(metadata_batch.clone())].into_iter(), - metadata_schema.clone(), - ), - ) + .create_table("metadata", metadata_batch.clone()) .execute() .await .unwrap(); @@ -393,14 +380,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, text_col, category_col]).unwrap(); - let table = db - .create_table( - table_name, - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table(table_name, batch).execute().await.unwrap(); // Create FTS index table @@ -546,14 +526,7 @@ mod tests { ])); let batch = RecordBatch::try_new(schema.clone(), vec![id_col, text_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS index with position information for phrase queries table @@ -691,14 +664,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "multi_col", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("multi_col", batch).execute().await.unwrap(); // Create FTS indices on both columns table @@ -963,13 +929,7 @@ mod tests { let metadata_batch = RecordBatch::try_new(metadata_schema.clone(), vec![metadata_id, extra_info]).unwrap(); let _metadata_table = db - .create_table( - "metadata", - RecordBatchIterator::new( - vec![Ok(metadata_batch.clone())].into_iter(), - metadata_schema, - ), - ) + .create_table("metadata", metadata_batch.clone()) .execute() .await .unwrap(); @@ -1358,14 +1318,7 @@ mod tests { ])); let batch = 
RecordBatch::try_new(schema.clone(), vec![id_col, text_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS index with position information table @@ -1510,14 +1463,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS indices on both columns table @@ -1591,14 +1537,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS indices table @@ -1724,36 +1663,23 @@ mod tests { .unwrap(); // Create table with simple text for n-gram testing - let data = RecordBatchIterator::new( - vec![RecordBatch::try_new( - Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("text", DataType::Utf8, false), - ])), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec![ - "hello world", - "lance database", - "lance is cool", - ])), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let data = RecordBatch::try_new( Arc::new(ArrowSchema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("text", DataType::Utf8, false), ])), - ); + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "hello world", + "lance database", + "lance is cool", + ])), + ], + ) + .unwrap(); - let table = Arc::new( - db.create_table("docs", 
Box::new(data)) - .execute() - .await - .unwrap(), - ); + let table = Arc::new(db.create_table("docs", data).execute().await.unwrap()); // Create FTS index with n-gram tokenizer (default min_ngram_length=3) table @@ -1876,43 +1802,29 @@ mod tests { .unwrap(); // Create table with two text columns - let data = RecordBatchIterator::new( - vec![RecordBatch::try_new( - Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("title", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), - ])), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec![ - "Important Document", - "Another Document", - "Random Text", - ])), - Arc::new(StringArray::from(vec![ - "This is important information", - "This has details", - "Nothing special here", - ])), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let data = RecordBatch::try_new( Arc::new(ArrowSchema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("title", DataType::Utf8, false), Field::new("content", DataType::Utf8, false), ])), - ); + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "Important Document", + "Another Document", + "Random Text", + ])), + Arc::new(StringArray::from(vec![ + "This is important information", + "This has details", + "Nothing special here", + ])), + ], + ) + .unwrap(); - let table = Arc::new( - db.create_table("docs", Box::new(data)) - .execute() - .await - .unwrap(), - ); + let table = Arc::new(db.create_table("docs", data).execute().await.unwrap()); // Create FTS indices on both columns table diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs index 7d17f46a2..4c4c304ba 100644 --- a/rust/lancedb/src/table/delete.rs +++ b/rust/lancedb/src/table/delete.rs @@ -34,7 +34,7 @@ pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Resu #[cfg(test)] mod tests { use crate::connect; - use arrow_array::{record_batch, 
Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{record_batch, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use std::sync::Arc; @@ -53,10 +53,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_delete", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_delete", batch) .execute() .await .unwrap(); @@ -102,10 +99,7 @@ mod tests { let original_schema = batch.schema(); let table = conn - .create_table( - "test_delete_all", - RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()), - ) + .create_table("test_delete_all", batch) .execute() .await .unwrap(); @@ -126,13 +120,8 @@ mod tests { // Create a table with 5 rows let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); - let schema = batch.schema(); - let table = conn - .create_table( - "test_delete_noop", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_delete_noop", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/optimize.rs b/rust/lancedb/src/table/optimize.rs index e278a8dbe..f75671e5b 100644 --- a/rust/lancedb/src/table/optimize.rs +++ b/rust/lancedb/src/table/optimize.rs @@ -212,7 +212,7 @@ pub(crate) async fn execute_optimize( #[cfg(test)] mod tests { - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use rstest::rstest; use std::sync::Arc; @@ -236,10 +236,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_compact", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_compact", batch) .execute() .await .unwrap(); @@ -253,11 +250,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Verify we have multiple fragments before 
compaction @@ -322,10 +315,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_prune", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_prune", batch) .execute() .await .unwrap(); @@ -339,11 +329,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Verify multiple versions exist @@ -405,10 +391,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_index_optimize", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_index_optimize", batch) .execute() .await .unwrap(); @@ -426,11 +409,7 @@ mod tests { vec![Arc::new(Int32Array::from_iter_values(100..200))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Verify index stats before optimization let indices = table.list_indices().await.unwrap(); @@ -474,10 +453,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_optimize_all", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_optimize_all", batch) .execute() .await .unwrap(); @@ -491,11 +467,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Run all optimizations @@ -559,20 +531,13 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_deferred_remap", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_deferred_remap", batch.clone()) .execute() .await .unwrap(); // Add more data - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Create an index table @@ 
-648,20 +613,13 @@ mod tests { let original_schema = batch.schema(); let table = conn - .create_table( - "test_schema_preserved", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_schema_preserved", batch.clone()) .execute() .await .unwrap(); // Add more data - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Run compaction table @@ -703,10 +661,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_empty_optimize", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_empty_optimize", batch) .execute() .await .unwrap(); @@ -752,19 +707,12 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_checkout_optimize", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_checkout_optimize", batch.clone()) .execute() .await .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); table.checkout(1).await.unwrap(); diff --git a/rust/lancedb/src/table/schema_evolution.rs b/rust/lancedb/src/table/schema_evolution.rs index 0f07e0a0b..3f774cd98 100644 --- a/rust/lancedb/src/table/schema_evolution.rs +++ b/rust/lancedb/src/table/schema_evolution.rs @@ -89,7 +89,7 @@ pub(crate) async fn execute_drop_columns( #[cfg(test)] mod tests { - use arrow_array::{record_batch, Int32Array, RecordBatchIterator, StringArray}; + use arrow_array::{record_batch, Int32Array, StringArray}; use arrow_schema::DataType; use futures::TryStreamExt; use lance::dataset::ColumnAlteration; @@ -105,13 +105,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_columns", - 
RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_columns", batch) .execute() .await .unwrap(); @@ -169,13 +165,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("x", Int32, [10, 20, 30])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_multi_columns", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_multi_columns", batch) .execute() .await .unwrap(); @@ -205,13 +197,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_const_column", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_const_column", batch) .execute() .await .unwrap(); @@ -255,13 +243,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("old_name", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_rename", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_rename", batch) .execute() .await .unwrap(); @@ -304,10 +288,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_alter_nullable", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_nullable", batch) .execute() .await .unwrap(); @@ -332,13 +313,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_cast_type", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_cast_type", batch) .execute() .await .unwrap(); @@ -379,13 +356,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("num", Int32, 
[1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_invalid_cast", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_invalid_cast", batch) .execute() .await .unwrap(); @@ -407,13 +380,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_multi", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_multi", batch) .execute() .await .unwrap(); @@ -441,13 +410,9 @@ mod tests { let batch = record_batch!(("keep", Int32, [1, 2, 3]), ("remove", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_single", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_single", batch) .execute() .await .unwrap(); @@ -478,13 +443,9 @@ mod tests { ("d", Int32, [7, 8]) ) .unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_multi", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_multi", batch) .execute() .await .unwrap(); @@ -511,13 +472,9 @@ mod tests { ("extra", Int32, [10, 20, 30]) ) .unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_preserves", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_preserves", batch) .execute() .await .unwrap(); @@ -567,13 +524,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_nonexistent", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_nonexistent", batch) .execute() .await .unwrap(); @@ -593,13 +546,9 @@ mod tests { let conn = 
connect("memory://").execute().await.unwrap(); let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_nonexistent", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_nonexistent", batch) .execute() .await .unwrap(); @@ -623,13 +572,8 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); - let table = conn - .create_table( - "test_version_increment", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_version_increment", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/update.rs b/rust/lancedb/src/table/update.rs index 6616dddc2..1d5ca5f4d 100644 --- a/rust/lancedb/src/table/update.rs +++ b/rust/lancedb/src/table/update.rs @@ -117,9 +117,8 @@ mod tests { use crate::query::{ExecutableQuery, Select}; use arrow_array::{ record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, - Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator, - RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray, - UInt32Array, + Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, StringArray, + TimestampMillisecondArray, TimestampNanosecondArray, UInt32Array, }; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit}; @@ -167,51 +166,46 @@ mod tests { ), ])); - let record_batch_iter = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(Int64Array::from_iter_values(0..10)), - Arc::new(UInt32Array::from_iter_values(0..10)), - Arc::new(StringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - 
Arc::new(LargeStringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), - Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), - Arc::new(Into::::into(vec![ - true, false, true, false, true, false, true, false, true, false, - ])), - Arc::new(Date32Array::from_iter_values(0..10)), - Arc::new(TimestampNanosecondArray::from_iter_values(0..10)), - Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), - Arc::new( - create_fixed_size_list( - Float32Array::from_iter_values((0..20).map(|i| i as f32)), - 2, - ) - .unwrap(), - ), - Arc::new( - create_fixed_size_list( - Float64Array::from_iter_values((0..20).map(|i| i as f64)), - 2, - ) - .unwrap(), - ), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let batch = RecordBatch::try_new( schema.clone(), - ); + vec![ + Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(Int64Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + Arc::new(StringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(LargeStringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), + Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), + Arc::new(Into::::into(vec![ + true, false, true, false, true, false, true, false, true, false, + ])), + Arc::new(Date32Array::from_iter_values(0..10)), + Arc::new(TimestampNanosecondArray::from_iter_values(0..10)), + Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), + Arc::new( + create_fixed_size_list( + Float32Array::from_iter_values((0..20).map(|i| i as f32)), + 2, + ) + .unwrap(), + ), + Arc::new( + create_fixed_size_list( + Float64Array::from_iter_values((0..20).map(|i| i as f64)), + 2, + ) + .unwrap(), + ), + ], + ) + .unwrap(); let table = conn - 
.create_table("my_table", record_batch_iter) + .create_table("my_table", batch) .execute() .await .unwrap(); @@ -338,15 +332,13 @@ mod tests { Ok(FixedSizeListArray::from(data)) } - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batch() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(0..10))], - )], - schema, + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10))], ) + .unwrap() } #[tokio::test] @@ -367,12 +359,8 @@ mod tests { ) .unwrap(); - let schema = batch.schema(); - // need the iterator for create table - let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema); - let table = conn - .create_table("my_table", record_batch_iter) + .create_table("my_table", batch) .execute() .await .unwrap(); @@ -430,7 +418,7 @@ mod tests { .await .unwrap(); let tbl = conn - .create_table("my_table", make_test_batches()) + .create_table("my_table", make_test_batch()) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/test_utils.rs b/rust/lancedb/src/test_utils.rs index daf749bcc..1e7fc742b 100644 --- a/rust/lancedb/src/test_utils.rs +++ b/rust/lancedb/src/test_utils.rs @@ -3,3 +3,4 @@ pub mod connection; pub mod datagen; +pub mod embeddings; diff --git a/rust/lancedb/src/test_utils/datagen.rs b/rust/lancedb/src/test_utils/datagen.rs index 15b79e5d8..6439fbc5f 100644 --- a/rust/lancedb/src/test_utils/datagen.rs +++ b/rust/lancedb/src/test_utils/datagen.rs @@ -34,10 +34,7 @@ impl LanceDbDatagenExt for BatchGeneratorBuilder { schema, )); let db = connect("memory:///").execute().await.unwrap(); - db.create_table_streaming(table_name, stream) - .execute() - .await - .unwrap() + db.create_table(table_name, stream).execute().await.unwrap() } } @@ -48,8 +45,5 @@ pub async fn virtual_table(name: &str, values: 
&RecordBatch) -> Table { schema, )); let db = connect("memory:///").execute().await.unwrap(); - db.create_table_streaming(name, stream) - .execute() - .await - .unwrap() + db.create_table(name, stream).execute().await.unwrap() } diff --git a/rust/lancedb/src/test_utils/embeddings.rs b/rust/lancedb/src/test_utils/embeddings.rs new file mode 100644 index 000000000..48c9b5743 --- /dev/null +++ b/rust/lancedb/src/test_utils/embeddings.rs @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::{borrow::Cow, sync::Arc}; + +use arrow_array::{Array, FixedSizeListArray, Float32Array}; +use arrow_schema::{DataType, Field}; + +use crate::embeddings::EmbeddingFunction; +use crate::Result; + +#[derive(Debug, Clone)] +pub struct MockEmbed { + name: String, + dim: usize, +} + +impl MockEmbed { + pub fn new(name: impl Into, dim: usize) -> Self { + Self { + name: name.into(), + dim, + } + } +} + +impl EmbeddingFunction for MockEmbed { + fn name(&self) -> &str { + &self.name + } + + fn source_type(&self) -> Result> { + Ok(Cow::Borrowed(&DataType::Utf8)) + } + + fn dest_type(&self) -> Result> { + Ok(Cow::Owned(DataType::new_fixed_size_list( + DataType::Float32, + self.dim as _, + true, + ))) + } + + fn compute_source_embeddings(&self, source: Arc) -> Result> { + // We can't use the FixedSizeListBuilder here because it always adds a null bitmap + // and we want to explicitly work with non-nullable arrays. 
+ let len = source.len(); + let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim])); + let field = Field::new("item", inner.data_type().clone(), false); + let arr = FixedSizeListArray::new(Arc::new(field), self.dim as _, inner, None); + + Ok(Arc::new(arr)) + } + + #[allow(unused_variables)] + fn compute_query_embeddings(&self, input: Arc) -> Result> { + todo!() + } +} diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs index 4c636aad4..b002e47ff 100644 --- a/rust/lancedb/tests/embedding_registry_test.rs +++ b/rust/lancedb/tests/embedding_registry_test.rs @@ -15,7 +15,6 @@ use arrow_array::{ use arrow_schema::{DataType, Field, Schema}; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry}, query::ExecutableQuery, @@ -253,7 +252,7 @@ async fn test_no_func_in_registry_on_add() -> Result<()> { Ok(()) } -fn create_some_records() -> Result { +fn create_some_records() -> Result> { const TOTAL: usize = 2; let schema = Arc::new(Schema::new(vec![ diff --git a/rust/lancedb/tests/object_store_test.rs b/rust/lancedb/tests/object_store_test.rs index c874fefb9..0d06eb678 100644 --- a/rust/lancedb/tests/object_store_test.rs +++ b/rust/lancedb/tests/object_store_test.rs @@ -4,7 +4,7 @@ #![cfg(feature = "s3-test")] use std::sync::Arc; -use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig}; @@ -111,7 +111,6 @@ async fn test_minio_lifecycle() -> Result<()> { .await?; let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let table = db.create_table("test_table", data).execute().await?; @@ -127,7 +126,6 @@ async fn test_minio_lifecycle() -> Result<()> { assert_eq!(row_count, 3); let data = 
test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); table.add(data).execute().await?; db.drop_table("test_table", &[]).await?; @@ -247,7 +245,6 @@ async fn test_encryption() -> Result<()> { // Create a table with encryption let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let mut builder = db.create_table("test_table", data); for (key, value) in CONFIG { @@ -274,7 +271,6 @@ async fn test_encryption() -> Result<()> { let table = db.open_table("test_table").execute().await?; let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); table.add(data).execute().await?; validate_objects_encrypted(&bucket.0, "test_table", &key.0).await; @@ -300,7 +296,6 @@ async fn test_table_storage_options_override() -> Result<()> { // Create table overriding with key2 encryption let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let _table = db .create_table("test_override", data) .storage_option("aws_sse_kms_key_id", &key2.0) @@ -312,7 +307,6 @@ async fn test_table_storage_options_override() -> Result<()> { // Also test that a table created without override uses connection settings let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let _table2 = db.create_table("test_inherit", data).execute().await?; // Verify this table uses key1 from connection @@ -419,7 +413,6 @@ async fn test_concurrent_dynamodb_commit() { .unwrap(); let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let table = db.create_table("test_table", data).execute().await.unwrap(); @@ -430,7 +423,6 @@ async fn test_concurrent_dynamodb_commit() { let table = db.open_table("test_table").execute().await.unwrap(); let data = data.clone(); tasks.push(tokio::spawn(async move { - let data = RecordBatchIterator::new(vec![Ok(data.clone())], 
data.schema()); table.add(data).execute().await.unwrap(); })); }