diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs index a117ab41d..824ae7562 100644 --- a/nodejs/src/connection.rs +++ b/nodejs/src/connection.rs @@ -13,6 +13,7 @@ use crate::header::JsHeaderProvider; use crate::table::Table; use crate::ConnectionOptions; use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection}; + use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema}; #[napi] diff --git a/python/src/connection.rs b/python/src/connection.rs index 2bd333ea0..7df89b101 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -121,7 +121,8 @@ impl Connection { let mode = Self::parse_create_mode_str(mode)?; - let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; + let batches: Box = + Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?); let mut builder = inner.create_table(name, batches).mode(mode); diff --git a/python/src/table.rs b/python/src/table.rs index 41e4df35f..353b22ff0 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -296,7 +296,8 @@ impl Table { data: Bound<'_, PyAny>, mode: String, ) -> PyResult> { - let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; + let batches: Box = + Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?); let mut op = self_.inner_ref()?.add(batches); if mode == "append" { op = op.mode(AddDataMode::Append); diff --git a/rust/lancedb/examples/bedrock.rs b/rust/lancedb/examples/bedrock.rs index 5cc7e0cbe..365453b60 100644 --- a/rust/lancedb/examples/bedrock.rs +++ b/rust/lancedb/examples/bedrock.rs @@ -3,13 +3,12 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{Float64Array, Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use aws_config::Region; use aws_sdk_bedrockruntime::Client; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, query::{ExecutableQuery, QueryBase}, @@ -67,7 +66,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![ Field::new("id", DataType::Int32, true), Field::new("text", DataType::Utf8, false), @@ -83,10 +82,9 @@ fn make_data() -> impl IntoArrow { ]); let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); let schema = Arc::new(schema); - let rb = RecordBatch::try_new( + RecordBatch::try_new( schema.clone(), vec![Arc::new(id), Arc::new(text), Arc::new(price)], ) - .unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + .unwrap() } diff --git a/rust/lancedb/examples/full_text_search.rs b/rust/lancedb/examples/full_text_search.rs index 264c41487..15d7d7d88 100644 --- a/rust/lancedb/examples/full_text_search.rs +++ b/rust/lancedb/examples/full_text_search.rs @@ -3,12 +3,13 @@ use std::sync::Arc; -use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray}; +use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lance_index::scalar::FullTextSearchQuery; use lancedb::connection::Connection; + use lancedb::index::scalar::FtsIndexBuilder; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -29,7 +30,7 @@ async fn main() -> Result<()> { Ok(()) } -fn create_some_records() -> Result> { +fn create_some_records() -> Result> { const 
TOTAL: usize = 1000; let schema = Arc::new(Schema::new(vec![ @@ -66,7 +67,7 @@ fn create_some_records() -> Result> { } async fn create_table(db: &Connection) -> Result { - let initial_data: Box = create_some_records()?; + let initial_data = create_some_records()?; let tbl = db.create_table("my_table", initial_data).execute().await?; Ok(tbl) } diff --git a/rust/lancedb/examples/hybrid_search.rs b/rust/lancedb/examples/hybrid_search.rs index 8a8dcda51..dc1a80a39 100644 --- a/rust/lancedb/examples/hybrid_search.rs +++ b/rust/lancedb/examples/hybrid_search.rs @@ -1,14 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use arrow_array::{RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lance_index::scalar::FullTextSearchQuery; use lancedb::index::scalar::FtsIndexBuilder; use lancedb::index::Index; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{ sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition, @@ -70,7 +69,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]); let facts = StringArray::from_iter_values(vec![ @@ -101,8 +100,7 @@ fn make_data() -> impl IntoArrow { "The first chatbot was ELIZA, created in the 1960s.", ]); let schema = Arc::new(schema); - let rb = RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap() } async fn create_index(table: &Table) -> Result<()> { diff --git a/rust/lancedb/examples/ivf_pq.rs b/rust/lancedb/examples/ivf_pq.rs index 9ce561c13..d171d5a45 100644 --- a/rust/lancedb/examples/ivf_pq.rs +++ b/rust/lancedb/examples/ivf_pq.rs @@ -8,13 +8,12 @@ use std::sync::Arc; use arrow_array::types::Float32Type; -use arrow_array::{ - FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, -}; +use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use lancedb::connection::Connection; + use lancedb::index::vector::IvfPqIndexBuilder; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -34,7 +33,7 @@ async fn main() -> Result<()> { Ok(()) } -fn create_some_records() -> Result> { +fn create_some_records() -> Result> { const TOTAL: usize = 1000; const DIM: usize = 128; @@ -73,9 +72,9 @@ fn create_some_records() -> Result> { } async fn create_table(db: &Connection) -> Result
{ - let initial_data: Box = create_some_records()?; + let initial_data = create_some_records()?; let tbl = db - .create_table("my_table", Box::new(initial_data)) + .create_table("my_table", initial_data) .execute() .await .unwrap(); diff --git a/rust/lancedb/examples/openai.rs b/rust/lancedb/examples/openai.rs index 73954f34a..2194e519a 100644 --- a/rust/lancedb/examples/openai.rs +++ b/rust/lancedb/examples/openai.rs @@ -5,11 +5,9 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_array::{RecordBatch, StringArray}; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, query::{ExecutableQuery, QueryBase}, @@ -64,26 +62,20 @@ async fn main() -> Result<()> { } // --8<-- [end:openai_embeddings] -fn make_data() -> impl IntoArrow { - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("text", DataType::Utf8, false), - Field::new("price", DataType::Float64, false), - ]); - - let id = Int32Array::from(vec![1, 2, 3, 4]); - let text = StringArray::from_iter_values(vec![ - "Black T-Shirt", - "Leather Jacket", - "Winter Parka", - "Hooded Sweatshirt", - ]); - let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); - let schema = Arc::new(schema); - let rb = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(id), Arc::new(text), Arc::new(price)], +fn make_data() -> RecordBatch { + arrow_array::record_batch!( + ("id", Int32, [1, 2, 3, 4]), + ( + "text", + Utf8, + [ + "Black T-Shirt", + "Leather Jacket", + "Winter Parka", + "Hooded Sweatshirt" + ] + ), + ("price", Float64, [10.0, 50.0, 100.0, 30.0]) ) - .unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + .unwrap() } diff --git a/rust/lancedb/examples/sentence_transformers.rs b/rust/lancedb/examples/sentence_transformers.rs index 9250430d4..2fa7ed30b 100644 --- a/rust/lancedb/examples/sentence_transformers.rs +++ b/rust/lancedb/examples/sentence_transformers.rs @@ -3,11 +3,10 @@ use std::{iter::once, sync::Arc}; -use arrow_array::{RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use futures::StreamExt; use lancedb::{ - arrow::IntoArrow, connect, embeddings::{ sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition, @@ -59,7 +58,7 @@ async fn main() -> Result<()> { Ok(()) } -fn make_data() -> impl IntoArrow { +fn make_data() -> RecordBatch { let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]); let facts = StringArray::from_iter_values(vec![ @@ -90,6 +89,5 @@ fn make_data() -> impl IntoArrow { "The first chatbot was ELIZA, created in the 1960s.", ]); let schema = Arc::new(schema); - let rb = RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap(); - Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) + RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap() } diff --git a/rust/lancedb/examples/simple.rs b/rust/lancedb/examples/simple.rs index 57b7be70b..157083d0a 100644 --- a/rust/lancedb/examples/simple.rs +++ b/rust/lancedb/examples/simple.rs @@ -8,11 +8,9 @@ use std::sync::Arc; use arrow_array::types::Float32Type; -use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator}; +use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch}; use arrow_schema::{DataType, 
Field, Schema}; use futures::TryStreamExt; - -use lancedb::arrow::IntoArrow; use lancedb::connection::Connection; use lancedb::index::Index; use lancedb::query::{ExecutableQuery, QueryBase}; @@ -59,7 +57,7 @@ async fn open_with_existing_tbl() -> Result<()> { Ok(()) } -fn create_some_records() -> Result { +fn create_some_records() -> Result { const TOTAL: usize = 1000; const DIM: usize = 128; @@ -76,25 +74,18 @@ fn create_some_records() -> Result { ])); // Create a RecordBatch stream. - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), - Arc::new( - FixedSizeListArray::from_iter_primitive::( - (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])), - DIM as i32, - ), - ), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + Ok(RecordBatch::try_new( schema.clone(), - ); - Ok(Box::new(batches)) + vec![ + Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), + Arc::new( + FixedSizeListArray::from_iter_primitive::( + (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])), + DIM as i32, + ), + ), + ], + )?) } async fn create_table(db: &Connection) -> Result { diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index f58aba4ea..1e22e7c8f 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -6,8 +6,8 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::RecordBatchReader; -use arrow_schema::{Field, SchemaRef}; +use arrow_array::RecordBatch; +use arrow_schema::SchemaRef; use lance::dataset::ReadParams; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, @@ -17,24 +17,20 @@ use lance_namespace::models::{ #[cfg(feature = "aws")] use object_store::aws::AwsCredential; -use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream}; -use crate::database::listing::{ - ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS, -}; +use crate::connection::create_table::CreateTableBuilder; +use crate::data::scannable::Scannable; +use crate::database::listing::ListingDatabase; use crate::database::{ - CloneTableRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database, - DatabaseOptions, OpenTableRequest, ReadConsistency, TableNamesRequest, -}; -use crate::embeddings::{ - EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings, + CloneTableRequest, Database, DatabaseOptions, OpenTableRequest, ReadConsistency, + TableNamesRequest, }; +use crate::embeddings::{EmbeddingRegistry, MemoryRegistry}; use crate::error::{Error, Result}; #[cfg(feature = "remote")] use crate::remote::{ client::ClientConfig, db::{OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION}, }; -use crate::table::{TableDefinition, WriteOptions}; use crate::Table; use lance::io::ObjectStoreParams; pub use lance_encoding::version::LanceFileVersion; @@ -42,6 +38,8 @@ pub use lance_encoding::version::LanceFileVersion; use lance_io::object_store::StorageOptions; use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider}; +mod create_table; + fn merge_storage_options( store_params: &mut ObjectStoreParams, pairs: impl IntoIterator, @@ -116,337 +114,6 @@ impl TableNamesBuilder { } } -pub struct NoData {} - -impl IntoArrow for NoData { - fn into_arrow(self) -> Result> { - unreachable!("NoData should never be converted to Arrow") - } -} - -// Stores the value given from the initial CreateTableBuilder::new call -// and defers errors 
until `execute` is called -enum CreateTableBuilderInitialData { - None, - Iterator(Result>), - Stream(Result), -} - -/// A builder for configuring a [`Connection::create_table`] operation -pub struct CreateTableBuilder { - parent: Arc, - embeddings: Vec<(EmbeddingDefinition, Arc)>, - embedding_registry: Arc, - request: CreateTableRequest, - // This is a bit clumsy but we defer errors until `execute` is called - // to maintain backwards compatibility - data: CreateTableBuilderInitialData, -} - -// Builder methods that only apply when we have initial data -impl CreateTableBuilder { - fn new( - parent: Arc, - name: String, - data: T, - embedding_registry: Arc, - ) -> Self { - let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::::default())); - Self { - parent, - request: CreateTableRequest::new( - name, - CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)), - ), - embeddings: Vec::new(), - embedding_registry, - data: CreateTableBuilderInitialData::Iterator(data.into_arrow()), - } - } - - fn new_streaming( - parent: Arc, - name: String, - data: T, - embedding_registry: Arc, - ) -> Self { - let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::::default())); - Self { - parent, - request: CreateTableRequest::new( - name, - CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)), - ), - embeddings: Vec::new(), - embedding_registry, - data: CreateTableBuilderInitialData::Stream(data.into_arrow()), - } - } - - /// Execute the create table operation - pub async fn execute(self) -> Result
{ - let embedding_registry = self.embedding_registry.clone(); - let parent = self.parent.clone(); - let request = self.into_request()?; - Ok(Table::new_with_embedding_registry( - parent.create_table(request).await?, - parent, - embedding_registry, - )) - } - - fn into_request(self) -> Result { - if self.embeddings.is_empty() { - match self.data { - CreateTableBuilderInitialData::Iterator(maybe_iter) => { - let data = maybe_iter?; - Ok(CreateTableRequest { - data: CreateTableData::Data(data), - ..self.request - }) - } - CreateTableBuilderInitialData::None => { - unreachable!("No data provided for CreateTableBuilder") - } - CreateTableBuilderInitialData::Stream(maybe_stream) => { - let data = maybe_stream?; - Ok(CreateTableRequest { - data: CreateTableData::StreamingData(data), - ..self.request - }) - } - } - } else { - let CreateTableBuilderInitialData::Iterator(maybe_iter) = self.data else { - return Err(Error::NotSupported { message: "Creating a table with embeddings is currently not support when the input is streaming".to_string() }); - }; - let data = maybe_iter?; - let data = Box::new(WithEmbeddings::new(data, self.embeddings)); - Ok(CreateTableRequest { - data: CreateTableData::Data(data), - ..self.request - }) - } - } -} - -// Builder methods that only apply when we do not have initial data -impl CreateTableBuilder { - fn new( - parent: Arc, - name: String, - schema: SchemaRef, - embedding_registry: Arc, - ) -> Self { - let table_definition = TableDefinition::new_from_schema(schema); - Self { - parent, - request: CreateTableRequest::new(name, CreateTableData::Empty(table_definition)), - data: CreateTableBuilderInitialData::None, - embeddings: Vec::default(), - embedding_registry, - } - } - - /// Execute the create table operation - pub async fn execute(self) -> Result
{ - let parent = self.parent.clone(); - let embedding_registry = self.embedding_registry.clone(); - let request = self.into_request()?; - Ok(Table::new_with_embedding_registry( - parent.create_table(request).await?, - parent, - embedding_registry, - )) - } - - fn into_request(self) -> Result { - if self.embeddings.is_empty() { - return Ok(self.request); - } - - let CreateTableData::Empty(table_def) = self.request.data else { - unreachable!("CreateTableBuilder should always have Empty data") - }; - - let schema = table_def.schema.clone(); - let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone()); - - let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::>()); - let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema); - let with_embeddings = WithEmbeddings::new(reader, self.embeddings); - let table_definition = with_embeddings.table_definition()?; - - Ok(CreateTableRequest { - data: CreateTableData::Empty(table_definition), - ..self.request - }) - } -} - -impl CreateTableBuilder { - /// Set the mode for creating the table - /// - /// This controls what happens if a table with the given name already exists - pub fn mode(mut self, mode: CreateTableMode) -> Self { - self.request.mode = mode; - self - } - - /// Apply the given write options when writing the initial data - pub fn write_options(mut self, write_options: WriteOptions) -> Self { - self.request.write_options = write_options; - self - } - - /// Set an option for the storage layer. - /// - /// Options already set on the connection will be inherited by the table, - /// but can be overridden here. - /// - /// See available options at - pub fn storage_option(mut self, key: impl Into, value: impl Into) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - merge_storage_options(store_params, [(key.into(), value.into())]); - self - } - - /// Set multiple options for the storage layer. - /// - /// Options already set on the connection will be inherited by the table, - /// but can be overridden here. - /// - /// See available options at - pub fn storage_options( - mut self, - pairs: impl IntoIterator, impl Into)>, - ) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - let updates = pairs - .into_iter() - .map(|(key, value)| (key.into(), value.into())); - merge_storage_options(store_params, updates); - self - } - - /// Add an embedding definition to the table. - /// - /// The `embedding_name` must match the name of an embedding function that - /// was previously registered with the connection's [`EmbeddingRegistry`]. - pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { - // Early verification of the embedding name - let embedding_func = self - .embedding_registry - .get(&definition.embedding_name) - .ok_or_else(|| Error::EmbeddingFunctionNotFound { - name: definition.embedding_name.clone(), - reason: "No embedding function found in the connection's embedding_registry" - .to_string(), - })?; - - self.embeddings.push((definition, embedding_func)); - Ok(self) - } - - /// Set whether to use V2 manifest paths for the table. (default: false) - /// - /// These paths provide more efficient opening of tables with many - /// versions on object stores. - /// - ///
<div class="warning">Turning this on will make the dataset unreadable
-    /// for older versions of LanceDB (prior to 0.10.0).</div>
- /// - /// To migrate an existing dataset, instead use the - /// [[NativeTable::migrate_manifest_paths_v2]]. - /// - /// This has no effect in LanceDB Cloud. - #[deprecated(since = "0.15.1", note = "Use `database_options` instead")] - pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert_with(Default::default) - .store_params - .get_or_insert_with(Default::default); - let value = if use_v2_manifest_paths { - "true".to_string() - } else { - "false".to_string() - }; - merge_storage_options( - store_params, - [(OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), value)], - ); - self - } - - /// Set the data storage version. - /// - /// The default is `LanceFileVersion::Stable`. - #[deprecated(since = "0.15.1", note = "Use `database_options` instead")] - pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert_with(Default::default) - .store_params - .get_or_insert_with(Default::default); - merge_storage_options( - store_params, - [( - OPT_NEW_TABLE_STORAGE_VERSION.to_string(), - data_storage_version.to_string(), - )], - ); - self - } - - /// Set the namespace for the table - pub fn namespace(mut self, namespace: Vec) -> Self { - self.request.namespace = namespace; - self - } - - /// Set a custom location for the table. - /// - /// If not set, the database will derive a location from its URI and the table name. - /// This is useful when integrating with namespace systems that manage table locations. - pub fn location(mut self, location: impl Into) -> Self { - self.request.location = Some(location.into()); - self - } - - /// Set a storage options provider for automatic credential refresh. - /// - /// This allows tables to automatically refresh cloud storage credentials - /// when they expire, enabling long-running operations on remote storage. 
- pub fn storage_options_provider(mut self, provider: Arc) -> Self { - let store_params = self - .request - .write_options - .lance_write_params - .get_or_insert(Default::default()) - .store_params - .get_or_insert(Default::default()); - set_storage_options_provider(store_params, provider); - self - } -} - #[derive(Clone, Debug)] pub struct OpenTableBuilder { parent: Arc, @@ -684,35 +351,17 @@ impl Connection { /// /// * `name` - The name of the table /// * `initial_data` - The initial data to write to the table - pub fn create_table( + pub fn create_table( &self, name: impl Into, initial_data: T, - ) -> CreateTableBuilder { - CreateTableBuilder::::new( + ) -> CreateTableBuilder { + let initial_data = Box::new(initial_data); + CreateTableBuilder::new( self.internal.clone(), + self.embedding_registry.clone(), name.into(), initial_data, - self.embedding_registry.clone(), - ) - } - - /// Create a new table from a stream of data - /// - /// # Parameters - /// - /// * `name` - The name of the table - /// * `initial_data` - The initial data to write to the table - pub fn create_table_streaming( - &self, - name: impl Into, - initial_data: T, - ) -> CreateTableBuilder { - CreateTableBuilder::::new_streaming( - self.internal.clone(), - name.into(), - initial_data, - self.embedding_registry.clone(), ) } @@ -726,13 +375,9 @@ impl Connection { &self, name: impl Into, schema: SchemaRef, - ) -> CreateTableBuilder { - CreateTableBuilder::::new( - self.internal.clone(), - name.into(), - schema, - self.embedding_registry.clone(), - ) + ) -> CreateTableBuilder { + let empty_batch = RecordBatch::new_empty(schema); + self.create_table(name, empty_batch) } /// Open an existing table in the database @@ -1349,20 +994,11 @@ mod test_utils { #[cfg(test)] mod tests { - use crate::database::listing::{ListingDatabaseOptions, NewTableConfig}; - use crate::query::QueryBase; - use crate::query::{ExecutableQuery, QueryExecutionOptions}; - use crate::test_utils::connection::new_test_connection; - use arrow::compute::concat_batches; - use arrow_array::RecordBatchReader; use arrow_schema::{DataType, Field, Schema}; - use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - use futures::{stream, TryStreamExt}; - use lance_core::error::{ArrowResult, DataFusionResult}; use lance_testing::datagen::{BatchGenerator, IncrementingInt32}; use tempfile::tempdir; - use crate::arrow::SimpleRecordBatchStream; + use crate::test_utils::connection::new_test_connection; use super::*; @@ -1478,139 +1114,6 @@ mod tests { assert_eq!(tables, vec!["table1".to_owned()]); } - fn make_data() -> Box { - let id = Box::new(IncrementingInt32::new().named("id".to_string())); - Box::new(BatchGenerator::new().col(id).batches(10, 2000)) - } - - #[tokio::test] - async fn test_create_table_v2() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri) - .database_options(&ListingDatabaseOptions { - new_table_config: NewTableConfig { - data_storage_version: Some(LanceFileVersion::Legacy), - ..Default::default() - }, - ..Default::default() - }) - .execute() - .await - .unwrap(); - - let tbl = db - .create_table("v1_test", make_data()) - .execute() - .await - .unwrap(); - - // In v1 the row group size will trump max_batch_length - let batches = tbl - .query() - .limit(20000) - .execute_with_options(QueryExecutionOptions { - max_batch_length: 50000, - ..Default::default() - }) - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - assert_eq!(batches.len(), 20); - - let db = connect(uri) - 
.database_options(&ListingDatabaseOptions { - new_table_config: NewTableConfig { - data_storage_version: Some(LanceFileVersion::Stable), - ..Default::default() - }, - ..Default::default() - }) - .execute() - .await - .unwrap(); - - let tbl = db - .create_table("v2_test", make_data()) - .execute() - .await - .unwrap(); - - // In v2 the page size is much bigger than 50k so we should get a single batch - let batches = tbl - .query() - .execute_with_options(QueryExecutionOptions { - max_batch_length: 50000, - ..Default::default() - }) - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - assert_eq!(batches.len(), 1); - } - - #[tokio::test] - async fn test_create_table_streaming() { - let tmp_dir = tempdir().unwrap(); - - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - - let batches = make_data().collect::>>().unwrap(); - - let schema = batches.first().unwrap().schema(); - let one_batch = concat_batches(&schema, batches.iter()).unwrap(); - - let ldb_stream = stream::iter(batches.clone().into_iter().map(Result::Ok)); - let ldb_stream: SendableRecordBatchStream = - Box::pin(SimpleRecordBatchStream::new(ldb_stream, schema.clone())); - - let tbl1 = db - .create_table_streaming("one", ldb_stream) - .execute() - .await - .unwrap(); - - let df_stream = stream::iter(batches.into_iter().map(DataFusionResult::Ok)); - let df_stream: datafusion_physical_plan::SendableRecordBatchStream = - Box::pin(RecordBatchStreamAdapter::new(schema.clone(), df_stream)); - - let tbl2 = db - .create_table_streaming("two", df_stream) - .execute() - .await - .unwrap(); - - let tbl1_data = tbl1 - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - let tbl1_data = concat_batches(&schema, tbl1_data.iter()).unwrap(); - assert_eq!(tbl1_data, one_batch); - - let tbl2_data = tbl2 - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - let tbl2_data = concat_batches(&schema, tbl2_data.iter()).unwrap(); - assert_eq!(tbl2_data, one_batch); - } - #[tokio::test] async fn drop_table() { let tc = new_test_connection().await.unwrap(); @@ -1640,41 +1143,6 @@ mod tests { assert_eq!(tables.len(), 0); } - #[tokio::test] - async fn test_create_table_already_exists() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); - db.create_empty_table("test", schema.clone()) - .execute() - .await - .unwrap(); - // TODO: None of the open table options are "inspectable" right now but once one is we - // should assert we are passing these options in correctly - db.create_empty_table("test", schema) - .mode(CreateTableMode::exist_ok(|mut req| { - req.index_cache_size = Some(16); - req - })) - .execute() - .await - .unwrap(); - let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)])); - assert!(db - .create_empty_table("test", other_schema.clone()) - .execute() - .await - .is_err()); - let overwritten = db - .create_empty_table("test", other_schema.clone()) - .mode(CreateTableMode::Overwrite) - .execute() - .await - .unwrap(); - assert_eq!(other_schema, overwritten.schema().await.unwrap()); - } - #[tokio::test] async fn test_clone_table() { let tmp_dir = tempdir().unwrap(); @@ -1685,7 +1153,8 @@ mod tests { let mut batch_gen = BatchGenerator::new() .col(Box::new(IncrementingInt32::new().named("id"))) 
.col(Box::new(IncrementingInt32::new().named("value"))); - let reader = batch_gen.batches(5, 100); + let reader: Box = + Box::new(batch_gen.batches(5, 100)); let source_table = db .create_table("source_table", reader) @@ -1720,128 +1189,4 @@ mod tests { let cloned_count = cloned_table.count_rows(None).await.unwrap(); assert_eq!(source_count, cloned_count); } - - #[tokio::test] - async fn test_create_empty_table_with_embeddings() { - use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction}; - use arrow_array::{ - Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray, - }; - use std::borrow::Cow; - - #[derive(Debug, Clone)] - struct MockEmbedding { - dim: usize, - } - - impl EmbeddingFunction for MockEmbedding { - fn name(&self) -> &str { - "test_embedding" - } - - fn source_type(&self) -> Result> { - Ok(Cow::Owned(DataType::Utf8)) - } - - fn dest_type(&self) -> Result> { - Ok(Cow::Owned(DataType::new_fixed_size_list( - DataType::Float32, - self.dim as i32, - true, - ))) - } - - fn compute_source_embeddings(&self, source: Arc) -> Result> { - let len = source.len(); - let values = vec![1.0f32; len * self.dim]; - let values = Arc::new(Float32Array::from(values)); - let field = Arc::new(Field::new("item", DataType::Float32, true)); - Ok(Arc::new(FixedSizeListArray::new( - field, - self.dim as i32, - values, - None, - ))) - } - - fn compute_query_embeddings(&self, _input: Arc) -> Result> { - unimplemented!() - } - } - - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let db = connect(uri).execute().await.unwrap(); - - let embed_func = Arc::new(MockEmbedding { dim: 128 }); - db.embedding_registry() - .register("test_embedding", embed_func.clone()) - .unwrap(); - - let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); - let ed = EmbeddingDefinition { - source_column: "name".to_owned(), - dest_column: Some("name_embedding".to_owned()), - embedding_name: "test_embedding".to_owned(), - }; - - let table = db - .create_empty_table("test", schema) - .mode(CreateTableMode::Overwrite) - .add_embedding(ed) - .unwrap() - .execute() - .await - .unwrap(); - - let table_schema = table.schema().await.unwrap(); - assert!(table_schema.column_with_name("name").is_some()); - assert!(table_schema.column_with_name("name_embedding").is_some()); - - let embedding_field = table_schema.field_with_name("name_embedding").unwrap(); - assert_eq!( - embedding_field.data_type(), - &DataType::new_fixed_size_list(DataType::Float32, 128, true) - ); - - let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); - let input_batch = RecordBatch::try_new( - input_schema.clone(), - vec![Arc::new(StringArray::from(vec![ - Some("Alice"), - Some("Bob"), - Some("Charlie"), - ]))], - ) - .unwrap(); - - let input_reader = Box::new(RecordBatchIterator::new( - vec![Ok(input_batch)].into_iter(), - input_schema, - )); - - table.add(input_reader).execute().await.unwrap(); - - let results = table - .query() - .execute() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - - assert_eq!(results.len(), 1); - let batch = &results[0]; - assert_eq!(batch.num_rows(), 3); - assert!(batch.column_by_name("name_embedding").is_some()); - - let embedding_col = batch - .column_by_name("name_embedding") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(embedding_col.len(), 3); - } } diff --git a/rust/lancedb/src/connection/create_table.rs b/rust/lancedb/src/connection/create_table.rs new file mode 
100644 index 000000000..8eb7d2207 --- /dev/null +++ b/rust/lancedb/src/connection/create_table.rs @@ -0,0 +1,612 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use lance_io::object_store::StorageOptionsProvider; + +use crate::{ + connection::{merge_storage_options, set_storage_options_provider}, + data::scannable::{Scannable, WithEmbeddingsScannable}, + database::{CreateTableMode, CreateTableRequest, Database}, + embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry}, + table::WriteOptions, + Error, Result, Table, +}; + +pub struct CreateTableBuilder { + parent: Arc, + embeddings: Vec<(EmbeddingDefinition, Arc)>, + embedding_registry: Arc, + request: CreateTableRequest, +} + +impl CreateTableBuilder { + pub(super) fn new( + parent: Arc, + embedding_registry: Arc, + name: String, + data: Box, + ) -> Self { + Self { + parent, + embeddings: Vec::new(), + embedding_registry, + request: CreateTableRequest::new(name, data), + } + } + + /// Set the mode for creating the table + /// + /// This controls what happens if a table with the given name already exists + pub fn mode(mut self, mode: CreateTableMode) -> Self { + self.request.mode = mode; + self + } + + /// Apply the given write options when writing the initial data + pub fn write_options(mut self, write_options: WriteOptions) -> Self { + self.request.write_options = write_options; + self + } + + /// Set an option for the storage layer. + /// + /// Options already set on the connection will be inherited by the table, + /// but can be overridden here. + /// + /// See available options at + pub fn storage_option(mut self, key: impl Into, value: impl Into) -> Self { + let store_params = self + .request + .write_options + .lance_write_params + .get_or_insert(Default::default()) + .store_params + .get_or_insert(Default::default()); + merge_storage_options(store_params, [(key.into(), value.into())]); + self + } + + /// Set multiple options for the storage layer. + /// + /// Options already set on the connection will be inherited by the table, + /// but can be overridden here. + /// + /// See available options at + pub fn storage_options( + mut self, + pairs: impl IntoIterator, impl Into)>, + ) -> Self { + let store_params = self + .request + .write_options + .lance_write_params + .get_or_insert(Default::default()) + .store_params + .get_or_insert(Default::default()); + let updates = pairs + .into_iter() + .map(|(key, value)| (key.into(), value.into())); + merge_storage_options(store_params, updates); + self + } + + /// Add an embedding definition to the table. + /// + /// The `embedding_name` must match the name of an embedding function that + /// was previously registered with the connection's [`EmbeddingRegistry`]. + pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { + // Early verification of the embedding name + let embedding_func = self + .embedding_registry + .get(&definition.embedding_name) + .ok_or_else(|| Error::EmbeddingFunctionNotFound { + name: definition.embedding_name.clone(), + reason: "No embedding function found in the connection's embedding_registry" + .to_string(), + })?; + + self.embeddings.push((definition, embedding_func)); + Ok(self) + } + + /// Set the namespace for the table + pub fn namespace(mut self, namespace: Vec) -> Self { + self.request.namespace = namespace; + self + } + + /// Set a custom location for the table. 
+    ///
+    /// If not set, the database will derive a location from its URI and the table name.
+    /// This is useful when integrating with namespace systems that manage table locations.
+    pub fn location(mut self, location: impl Into<String>) -> Self {
+        self.request.location = Some(location.into());
+        self
+    }
+
+    /// Set a storage options provider for automatic credential refresh.
+    ///
+    /// This allows tables to automatically refresh cloud storage credentials
+    /// when they expire, enabling long-running operations on remote storage.
+    pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
+        let store_params = self
+            .request
+            .write_options
+            .lance_write_params
+            .get_or_insert(Default::default())
+            .store_params
+            .get_or_insert(Default::default());
+        set_storage_options_provider(store_params, provider);
+        self
+    }
+
+    /// Execute the create table operation
+    pub async fn execute(mut self) -> Result<Table> {
+        let embedding_registry = self.embedding_registry.clone();
+        let parent = self.parent.clone();
+
+        // If embeddings were configured via add_embedding(), wrap the data
+        if !self.embeddings.is_empty() {
+            let wrapped_data: Box<dyn Scannable> = Box::new(WithEmbeddingsScannable::try_new(
+                self.request.data,
+                self.embeddings,
+            )?);
+            self.request.data = wrapped_data;
+        }
+
+        Ok(Table::new_with_embedding_registry(
+            parent.create_table(self.request).await?,
+            parent,
+            embedding_registry,
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow_array::{
+        record_batch, Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator,
+    };
+    use arrow_schema::{ArrowError, DataType, Field, Schema};
+    use futures::TryStreamExt;
+    use lance_file::version::LanceFileVersion;
+    use tempfile::tempdir;
+
+    use crate::{
+        arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+        connect,
+        database::listing::{ListingDatabaseOptions, NewTableConfig},
+        embeddings::{EmbeddingDefinition, EmbeddingFunction, MemoryRegistry},
+        query::{ExecutableQuery, QueryBase, Select},
+        test_utils::embeddings::MockEmbed,
+    };
+    use std::borrow::Cow;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn create_empty_table() {
+        let db = connect("memory://").execute().await.unwrap();
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int64, false),
+            Field::new("value", DataType::Float64, false),
+        ]));
+        db.create_empty_table("name", schema.clone())
+            .execute()
+            .await
+            .unwrap();
+        let table = db.open_table("name").execute().await.unwrap();
+        assert_eq!(table.schema().await.unwrap(), schema);
+        assert_eq!(table.count_rows(None).await.unwrap(), 0);
+    }
+
+    async fn test_create_table_with_data<T>(data: T)
+    where
+        T: Scannable + 'static,
+    {
+        let db = connect("memory://").execute().await.unwrap();
+        let schema = data.schema();
+        db.create_table("data_table", data).execute().await.unwrap();
+        let table = db.open_table("data_table").execute().await.unwrap();
+        assert_eq!(table.count_rows(None).await.unwrap(), 3);
+        assert_eq!(table.schema().await.unwrap(), schema);
+    }
+
+    #[tokio::test]
+    async fn create_table_with_batch() {
+        let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
+        test_create_table_with_data(batch).await;
+    }
+
+    #[tokio::test]
+    async fn test_create_table_with_vec_batch() {
+        let data = vec![
+            record_batch!(("id", Int64, [1, 2])).unwrap(),
+            record_batch!(("id", Int64, [3])).unwrap(),
+        ];
+        test_create_table_with_data(data).await;
+    }
+
+    #[tokio::test]
+    async fn test_create_table_with_record_batch_reader() {
+        let data = vec![
+            record_batch!(("id", Int64, [1, 2])).unwrap(),
+            record_batch!(("id", Int64, [3])).unwrap(),
+        ];
+        let schema = data[0].schema();
+        let reader: Box<dyn RecordBatchReader + Send> = Box::new(
+            RecordBatchIterator::new(data.into_iter().map(Ok), schema.clone()),
+        );
+        test_create_table_with_data(reader).await;
+    }
+
+    #[tokio::test]
+    async fn test_create_table_with_stream() {
+        let data = vec![
+            record_batch!(("id", Int64, [1, 2])).unwrap(),
+            record_batch!(("id", Int64, [3])).unwrap(),
+        ];
+        let schema = data[0].schema();
+        let inner = futures::stream::iter(data.into_iter().map(Ok));
+        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
+            schema,
+            stream: inner,
+        });
+        test_create_table_with_data(stream).await;
+    }
+
+    #[derive(Debug)]
+    struct MyError;
+
+    impl std::fmt::Display for MyError {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(f, "MyError occurred")
+        }
+    }
+
+    impl std::error::Error for MyError {}
+
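// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): what the unified `create_table`
// entry point above looks like from the caller's side. This mirrors the tests
// in this module; the "memory://" URI and table names are placeholders.
async fn demo_unified_create_table() -> Result<()> {
    let db = connect("memory://").execute().await?;

    // A bare RecordBatch is Scannable and can be passed directly...
    let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
    db.create_table("from_batch", batch).execute().await?;

    // ...and so is a Vec<RecordBatch>, with no RecordBatchIterator boilerplate.
    let data = vec![
        record_batch!(("id", Int64, [4, 5])).unwrap(),
        record_batch!(("id", Int64, [6])).unwrap(),
    ];
    db.create_table("from_vec", data).execute().await?;
    Ok(())
}
// ---------------------------------------------------------------------------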
#[tokio::test] + async fn test_create_preserves_reader_error() { + let first_batch = record_batch!(("id", Int64, [1, 2])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(ArrowError::ExternalError(Box::new(MyError))), + ]; + let reader: Box = Box::new( + RecordBatchIterator::new(iterator.into_iter(), schema.clone()), + ); + + let db = connect("memory://").execute().await.unwrap(); + let result = db.create_table("failing_table", reader).execute().await; + + assert!(result.is_err()); + // TODO: when we upgrade to Lance 2.0.0, this should pass + // assert!(matches!(result, Err(Error::External { source}) + // if source.downcast_ref::().is_some() + // )); + } + + #[tokio::test] + async fn test_create_preserves_stream_error() { + let first_batch = record_batch!(("id", Int64, [1, 2])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(Error::External { + source: Box::new(MyError), + }), + ]; + let stream = futures::stream::iter(iterator); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream, + }); + + let db = connect("memory://").execute().await.unwrap(); + let result = db + .create_table("failing_stream_table", stream) + .execute() + .await; + + assert!(result.is_err()); + // TODO: when we upgrade to Lance 2.0.0, this should pass + // assert!(matches!(result, Err(Error::External { source}) + // if source.downcast_ref::().is_some() + // )); + } + + #[tokio::test] + #[allow(deprecated)] + async fn test_create_table_with_storage_options() { + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + let db = connect("memory://").execute().await.unwrap(); + + let table = db + .create_table("options_table", batch) + .storage_option("timeout", "30s") + .storage_options([("retry_count", "3")]) + .execute() + .await + .unwrap(); + + let final_options = table.storage_options().await.unwrap(); + assert_eq!(final_options.get("timeout"), Some(&"30s".to_string())); + assert_eq!(final_options.get("retry_count"), Some(&"3".to_string())); + } + + #[tokio::test] + async fn test_create_table_unregistered_embedding() { + let db = connect("memory://").execute().await.unwrap(); + let batch = record_batch!(("text", Utf8, ["hello", "world"])).unwrap(); + + // Try to add an embedding that doesn't exist in the registry + let result = db + .create_table("embed_table", batch) + .add_embedding(EmbeddingDefinition::new( + "text", + "nonexistent_embedding_function", + None::<&str>, + )); + + match result { + Err(Error::EmbeddingFunctionNotFound { name, .. 
}) => { + assert_eq!(name, "nonexistent_embedding_function"); + } + Err(other) => panic!("Expected EmbeddingFunctionNotFound error, got: {:?}", other), + Ok(_) => panic!("Expected error, but got Ok"), + } + } + + #[tokio::test] + async fn test_create_table_already_exists() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let db = connect(uri).execute().await.unwrap(); + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); + db.create_empty_table("test", schema.clone()) + .execute() + .await + .unwrap(); + db.create_empty_table("test", schema) + .mode(CreateTableMode::exist_ok(|mut req| { + req.index_cache_size = Some(16); + req + })) + .execute() + .await + .unwrap(); + let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)])); + assert!(db + .create_empty_table("test", other_schema.clone()) + .execute() + .await + .is_err()); // TODO: assert what this error is + let overwritten = db + .create_empty_table("test", other_schema.clone()) + .mode(CreateTableMode::Overwrite) + .execute() + .await + .unwrap(); + assert_eq!(other_schema, overwritten.schema().await.unwrap()); + } + + #[tokio::test] + #[rstest::rstest] + #[case(LanceFileVersion::Legacy)] + #[case(LanceFileVersion::Stable)] + async fn test_create_table_with_storage_version( + #[case] data_storage_version: LanceFileVersion, + ) { + let db = connect("memory://") + .database_options(&ListingDatabaseOptions { + new_table_config: NewTableConfig { + data_storage_version: Some(data_storage_version), + ..Default::default() + }, + ..Default::default() + }) + .execute() + .await + .unwrap(); + + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + let table = db + .create_table("legacy_table", batch) + .execute() + .await + .unwrap(); + + let native_table = table.as_native().unwrap(); + let storage_format = native_table + .manifest() + .await + .unwrap() + .data_storage_format + .lance_file_version() + .unwrap(); + // Compare resolved versions since Stable/Next are aliases that resolve at storage time + assert_eq!(storage_format.resolve(), data_storage_version.resolve()); + } + + #[tokio::test] + async fn test_create_table_with_embedding() { + // Register the mock embedding function + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + // Connect with the custom registry + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + // Create data without the embedding column + let batch = record_batch!(("text", Utf8, ["hello", "world", "test"])).unwrap(); + + // Create table with add_embedding - embeddings should be computed automatically + let table = conn + .create_table("embed_test", batch) + .add_embedding(EmbeddingDefinition::new( + "text", + "mock", + Some("text_embedding"), + )) + .unwrap() + .execute() + .await + .unwrap(); + + // Verify row count + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + // Verify the schema includes the embedding column + let result_schema = table.schema().await.unwrap(); + assert_eq!(result_schema.fields().len(), 2); + assert_eq!(result_schema.field(0).name(), "text"); + assert_eq!(result_schema.field(1).name(), "text_embedding"); + + // Verify the embedding column has the correct type + assert!(matches!( + result_schema.field(1).data_type(), + DataType::FixedSizeList(_, 4) + )); + + // Query to verify the embeddings were computed + 
let results: Vec = table + .query() + .select(Select::columns(&["text", "text_embedding"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + + // Check that all rows have embedding values (not null) + for batch in &results { + let embedding_col = batch.column(1); + assert_eq!(embedding_col.null_count(), 0); + assert_eq!(embedding_col.len(), batch.num_rows()); + } + + // Verify the schema metadata contains the column definitions + assert!( + result_schema + .metadata + .contains_key("lancedb::column_definitions"), + "Schema metadata should contain column definitions" + ); + } + + #[tokio::test] + async fn test_create_empty_table_with_embeddings() { + #[derive(Debug, Clone)] + struct MockEmbedding { + dim: usize, + } + + impl EmbeddingFunction for MockEmbedding { + fn name(&self) -> &str { + "test_embedding" + } + + fn source_type(&self) -> Result> { + Ok(Cow::Owned(DataType::Utf8)) + } + + fn dest_type(&self) -> Result> { + Ok(Cow::Owned(DataType::new_fixed_size_list( + DataType::Float32, + self.dim as i32, + true, + ))) + } + + fn compute_source_embeddings(&self, source: Arc) -> Result> { + let len = source.len(); + let values = vec![1.0f32; len * self.dim]; + let values = Arc::new(Float32Array::from(values)); + let field = Arc::new(Field::new("item", DataType::Float32, true)); + Ok(Arc::new(FixedSizeListArray::new( + field, + self.dim as i32, + values, + None, + ))) + } + + fn compute_query_embeddings(&self, _input: Arc) -> Result> { + unimplemented!() + } + } + + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let db = connect(uri).execute().await.unwrap(); + + let embed_func = Arc::new(MockEmbedding { dim: 128 }); + db.embedding_registry() + .register("test_embedding", embed_func.clone()) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + let ed = EmbeddingDefinition { + source_column: "name".to_owned(), + dest_column: Some("name_embedding".to_owned()), + embedding_name: "test_embedding".to_owned(), + }; + + let table = db + .create_empty_table("test", schema) + .mode(CreateTableMode::Overwrite) + .add_embedding(ed) + .unwrap() + .execute() + .await + .unwrap(); + + let table_schema = table.schema().await.unwrap(); + assert!(table_schema.column_with_name("name").is_some()); + assert!(table_schema.column_with_name("name_embedding").is_some()); + + let embedding_field = table_schema.field_with_name("name_embedding").unwrap(); + assert_eq!( + embedding_field.data_type(), + &DataType::new_fixed_size_list(DataType::Float32, 128, true) + ); + + let input_batch = record_batch!(("name", Utf8, ["Alice", "Bob", "Charlie"])).unwrap(); + table.add(input_batch).execute().await.unwrap(); + + let results = table + .query() + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 3); + assert!(batch.column_by_name("name_embedding").is_some()); + + let embedding_col = batch + .column_by_name("name_embedding") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(embedding_col.len(), 3); + } +} diff --git a/rust/lancedb/src/data.rs b/rust/lancedb/src/data.rs index d988884a1..16840038f 100644 --- a/rust/lancedb/src/data.rs +++ b/rust/lancedb/src/data.rs @@ -5,3 +5,4 @@ pub mod inspect; pub mod sanitize; +pub mod scannable; diff --git a/rust/lancedb/src/data/scannable.rs 
b/rust/lancedb/src/data/scannable.rs
new file mode 100644
index 000000000..350742bd7
--- /dev/null
+++ b/rust/lancedb/src/data/scannable.rs
@@ -0,0 +1,580 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! Data source abstraction for LanceDB.
+//!
+//! This module provides a [`Scannable`] trait that allows input data sources to express
+//! capabilities (row count, rescannability) so the insert pipeline can make
+//! better decisions about write parallelism and retry strategies.
+
+use std::sync::Arc;
+
+use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow_schema::{ArrowError, SchemaRef};
+use async_trait::async_trait;
+use futures::stream::once;
+use futures::StreamExt;
+use lance_datafusion::utils::StreamingWriteSource;
+
+use crate::arrow::{
+    SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream,
+};
+use crate::embeddings::{
+    compute_embeddings_for_batch, compute_output_schema, EmbeddingDefinition, EmbeddingFunction,
+    EmbeddingRegistry,
+};
+use crate::table::{ColumnDefinition, ColumnKind, TableDefinition};
+use crate::{Error, Result};
+
+pub trait Scannable: Send {
+    /// Returns the schema of the data.
+    fn schema(&self) -> SchemaRef;
+
+    /// Read data as a stream of record batches.
+    ///
+    /// For rescannable sources (in-memory data like RecordBatch, Vec<RecordBatch>),
+    /// this can be called multiple times and returns cloned data each time.
+    ///
+    /// For non-rescannable sources (streams, readers), this can only be called once.
+    /// Calling it a second time returns a stream whose first item is an error.
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream;
+
+    /// Optional hint about the number of rows.
+    ///
+    /// When available, this allows the pipeline to estimate total data size
+    /// and choose appropriate partitioning.
+    fn num_rows(&self) -> Option<usize> {
+        None
+    }
+
+    /// Whether the source can be re-read from the beginning.
+    ///
+    /// `true` for in-memory data (Tables, DataFrames) and disk-based sources (Datasets).
+    /// `false` for streaming sources (DuckDB results, network streams).
+    ///
+    /// When true, the pipeline can retry failed writes by rescanning.
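// Reviewer sketch (not part of the patch): the intended caller-side use of
// this flag. A write pipeline can guard its retry path on it, for example:
//
//     let first_try = write(source.scan_as_stream()).await;
//     if first_try.is_err() && source.rescannable() {
//         // In-memory sources can safely be re-read from the start.
//         write(source.scan_as_stream()).await?;
//     }
//
// where `write` is a hypothetical helper. Non-rescannable sources would
// instead surface an "already been consumed" error on the second scan.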
+    fn rescannable(&self) -> bool {
+        false
+    }
+}
+
+impl std::fmt::Debug for dyn Scannable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Scannable")
+            .field("schema", &self.schema())
+            .field("num_rows", &self.num_rows())
+            .field("rescannable", &self.rescannable())
+            .finish()
+    }
+}
+
+impl Scannable for RecordBatch {
+    fn schema(&self) -> SchemaRef {
+        Self::schema(self)
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        let batch = self.clone();
+        let schema = batch.schema();
+        Box::pin(SimpleRecordBatchStream {
+            schema,
+            stream: once(async move { Ok(batch) }),
+        })
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        Some(Self::num_rows(self))
+    }
+
+    fn rescannable(&self) -> bool {
+        true
+    }
+}
+
+impl Scannable for Vec<RecordBatch> {
+    fn schema(&self) -> SchemaRef {
+        if self.is_empty() {
+            Arc::new(arrow_schema::Schema::empty())
+        } else {
+            self[0].schema()
+        }
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        if self.is_empty() {
+            let schema = Scannable::schema(self);
+            return Box::pin(SimpleRecordBatchStream {
+                schema,
+                stream: once(async {
+                    Err(Error::InvalidInput {
+                        message: "Cannot scan an empty Vec".to_string(),
+                    })
+                }),
+            });
+        }
+        let schema = Scannable::schema(self);
+        let batches = self.clone();
+        let stream = futures::stream::iter(batches.into_iter().map(Ok));
+        Box::pin(SimpleRecordBatchStream { schema, stream })
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        Some(self.iter().map(|b| b.num_rows()).sum())
+    }
+
+    fn rescannable(&self) -> bool {
+        true
+    }
+}
+
+impl Scannable for Box<dyn RecordBatchReader + Send> {
+    fn schema(&self) -> SchemaRef {
+        RecordBatchReader::schema(self.as_ref())
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        let schema = Scannable::schema(self);
+
+        // Swap self with a reader that errors on iteration, so a second call
+        // produces a clear error instead of silently returning empty data.
+        let err_reader: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
+            vec![Err(ArrowError::InvalidArgumentError(
+                "Reader has already been consumed".into(),
+            ))],
+            schema.clone(),
+        ));
+        let reader = std::mem::replace(self, err_reader);
+
+        // Bridge the blocking RecordBatchReader to an async stream via a channel.
+        let (tx, rx) = tokio::sync::mpsc::channel::<Result<RecordBatch>>(2);
+        tokio::task::spawn_blocking(move || {
+            for batch_result in reader {
+                let result = batch_result.map_err(Into::into);
+                if tx.blocking_send(result).is_err() {
+                    break;
+                }
+            }
+        });
+
+        let stream = futures::stream::unfold(rx, |mut rx| async move {
+            rx.recv().await.map(|batch| (batch, rx))
+        })
+        .fuse();
+
+        Box::pin(SimpleRecordBatchStream { schema, stream })
+    }
+}
+
+impl Scannable for SendableRecordBatchStream {
+    fn schema(&self) -> SchemaRef {
+        self.as_ref().schema()
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        let schema = Scannable::schema(self);
+
+        // Swap self with an error stream so a second call produces a clear error.
+        let error_stream = Box::pin(SimpleRecordBatchStream {
+            schema: schema.clone(),
+            stream: once(async {
+                Err(Error::InvalidInput {
+                    message: "Stream has already been consumed".to_string(),
+                })
+            }),
+        });
+        std::mem::replace(self, error_stream)
+    }
+}
+
+#[async_trait]
+impl StreamingWriteSource for Box<dyn Scannable> {
+    fn arrow_schema(&self) -> SchemaRef {
+        self.schema()
+    }
+
+    fn into_stream(mut self) -> datafusion_physical_plan::SendableRecordBatchStream {
+        self.scan_as_stream().into_df_stream()
+    }
+}
+
+/// A scannable that applies embeddings to the stream.
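// Reviewer sketch (not part of the patch): a minimal custom implementation of
// the `Scannable` trait above, assuming a hypothetical in-memory source that
// holds a (schema, batches) pair and can therefore advertise rescannability:
//
//     struct InMemorySource {
//         schema: SchemaRef,
//         batches: Vec<RecordBatch>,
//     }
//
//     impl Scannable for InMemorySource {
//         fn schema(&self) -> SchemaRef {
//             self.schema.clone()
//         }
//         fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
//             // Cloning makes every scan independent, so this is rescannable.
//             let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok));
//             Box::pin(SimpleRecordBatchStream { schema: self.schema.clone(), stream })
//         }
//         fn num_rows(&self) -> Option<usize> {
//             Some(self.batches.iter().map(|b| b.num_rows()).sum())
//         }
//         fn rescannable(&self) -> bool {
//             true
//         }
//     }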
+pub struct WithEmbeddingsScannable {
+    inner: Box<dyn Scannable>,
+    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
+    output_schema: SchemaRef,
+}
+
+impl WithEmbeddingsScannable {
+    /// Create a new WithEmbeddingsScannable.
+    ///
+    /// The embeddings are applied to the inner scannable's data as new columns.
+    pub fn try_new(
+        inner: Box<dyn Scannable>,
+        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
+    ) -> Result<Self> {
+        let output_schema = compute_output_schema(&inner.schema(), &embeddings)?;
+
+        // Build column definitions: Physical for base columns, Embedding for new ones
+        let base_col_count = inner.schema().fields().len();
+        let column_definitions: Vec<ColumnDefinition> = (0..base_col_count)
+            .map(|_| ColumnDefinition {
+                kind: ColumnKind::Physical,
+            })
+            .chain(embeddings.iter().map(|(ed, _)| ColumnDefinition {
+                kind: ColumnKind::Embedding(ed.clone()),
+            }))
+            .collect();
+
+        let table_definition = TableDefinition::new(output_schema, column_definitions);
+        let output_schema = table_definition.into_rich_schema();
+
+        Ok(Self {
+            inner,
+            embeddings,
+            output_schema,
+        })
+    }
+}
+
+impl Scannable for WithEmbeddingsScannable {
+    fn schema(&self) -> SchemaRef {
+        self.output_schema.clone()
+    }
+
+    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
+        let inner_stream = self.inner.scan_as_stream();
+        let embeddings = self.embeddings.clone();
+        let output_schema = self.output_schema.clone();
+
+        let mapped_stream = inner_stream.then(move |batch_result| {
+            let embeddings = embeddings.clone();
+            async move {
+                let batch = batch_result?;
+                let result = tokio::task::spawn_blocking(move || {
+                    compute_embeddings_for_batch(batch, &embeddings)
+                })
+                .await
+                .map_err(|e| Error::Runtime {
+                    message: format!("Task panicked during embedding computation: {}", e),
+                })??;
+                Ok(result)
+            }
+        });
+
+        Box::pin(SimpleRecordBatchStream {
+            schema: output_schema,
+            stream: mapped_stream,
+        })
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        self.inner.num_rows()
+    }
+
+    fn rescannable(&self) -> bool {
+        self.inner.rescannable()
+    }
+}
+
+pub fn scannable_with_embeddings(
+    inner: Box<dyn Scannable>,
+    table_definition: &TableDefinition,
+    registry: Option<&Arc<dyn EmbeddingRegistry>>,
+) -> Result<Box<dyn Scannable>> {
+    if let Some(registry) = registry {
+        let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
+        for cd in table_definition.column_definitions.iter() {
+            if let ColumnKind::Embedding(embedding_def) = &cd.kind {
+                match registry.get(&embedding_def.embedding_name) {
+                    Some(func) => {
+                        embeddings.push((embedding_def.clone(), func));
+                    }
+                    None => {
+                        return Err(Error::EmbeddingFunctionNotFound {
+                            name: embedding_def.embedding_name.clone(),
+                            reason: format!(
+                                "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
+                                embedding_def.embedding_name
+                            ),
+                        });
+                    }
+                }
+            }
+        }
+
+        if !embeddings.is_empty() {
+            return Ok(Box::new(WithEmbeddingsScannable::try_new(
+                inner, embeddings,
+            )?));
+        }
+    }
+
+    Ok(inner)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::record_batch;
+    use futures::TryStreamExt;
+
+    #[tokio::test]
+    async fn test_record_batch_rescannable() {
+        let mut batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap();
+
+        let stream1 = batch.scan_as_stream();
+        let batches1: Vec<RecordBatch> = stream1.try_collect().await.unwrap();
+        assert_eq!(batches1.len(), 1);
+        assert_eq!(batches1[0], batch);
+
+        assert!(batch.rescannable());
+        let stream2 = batch.scan_as_stream();
+        let batches2: Vec<RecordBatch> = stream2.try_collect().await.unwrap();
+        assert_eq!(batches2.len(), 1);
+        assert_eq!(batches2[0],
batch); + } + + #[tokio::test] + async fn test_vec_batch_rescannable() { + let mut batches = vec![ + record_batch!(("id", Int64, [0, 1])).unwrap(), + record_batch!(("id", Int64, [2, 3, 4])).unwrap(), + ]; + + let stream1 = batches.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 2); + assert_eq!(result1[0], batches[0]); + assert_eq!(result1[1], batches[1]); + + assert!(batches.rescannable()); + let stream2 = batches.scan_as_stream(); + let result2: Vec = stream2.try_collect().await.unwrap(); + assert_eq!(result2.len(), 2); + assert_eq!(result2[0], batches[0]); + assert_eq!(result2[1], batches[1]); + } + + #[tokio::test] + async fn test_vec_batch_empty_errors() { + let mut empty: Vec = vec![]; + let mut stream = empty.scan_as_stream(); + let result = stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + } + + #[tokio::test] + async fn test_reader_not_rescannable() { + let batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap(); + let schema = batch.schema(); + let mut reader: Box = Box::new( + RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), + ); + + let stream1 = reader.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 1); + assert_eq!(result1[0], batch); + + assert!(!reader.rescannable()); + // Second call returns a stream whose first item is an error + let mut stream2 = reader.scan_as_stream(); + let result2 = stream2.next().await; + assert!(result2.is_some()); + assert!(result2.unwrap().is_err()); + } + + #[tokio::test] + async fn test_stream_not_rescannable() { + let batch = record_batch!(("id", Int64, [0, 1, 2])).unwrap(); + let schema = batch.schema(); + let inner_stream = futures::stream::iter(vec![Ok(batch.clone())]); + let mut stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream: inner_stream, + }); + + let stream1 = stream.scan_as_stream(); + let result1: Vec = stream1.try_collect().await.unwrap(); + assert_eq!(result1.len(), 1); + assert_eq!(result1[0], batch); + + assert!(!stream.rescannable()); + // Second call returns a stream whose first item is an error + let mut stream2 = stream.scan_as_stream(); + let result2 = stream2.next().await; + assert!(result2.is_some()); + assert!(result2.unwrap().is_err()); + } + + mod embedding_tests { + use super::*; + use crate::embeddings::MemoryRegistry; + use crate::table::{ColumnDefinition, ColumnKind}; + use crate::test_utils::embeddings::MockEmbed; + use arrow_array::Array as _; + use arrow_array::{ArrayRef, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + + #[tokio::test] + async fn test_with_embeddings_scannable() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = StringArray::from(vec!["hello", "world", "test"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + + let mut scannable = WithEmbeddingsScannable::try_new( + Box::new(batch.clone()), + vec![(embedding_def, mock_embedding)], + ) + .unwrap(); + + // Check that schema has the embedding column + let output_schema = scannable.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "text"); + assert_eq!(output_schema.field(1).name(), 
"text_embedding"); + + // Check num_rows and rescannable are preserved + assert_eq!(scannable.num_rows(), Some(3)); + assert!(scannable.rescannable()); + + // Read the data + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + + let result_batch = &results[0]; + assert_eq!(result_batch.num_rows(), 3); + assert_eq!(result_batch.num_columns(), 2); + + // Verify the embedding column is present and has the right shape + let embedding_col = result_batch.column(1); + assert_eq!(embedding_col.len(), 3); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_no_embeddings() { + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + + // Create a table definition with no embedding columns + let table_def = TableDefinition::new_from_schema(batch.schema()); + + // Even with a registry, if there are no embedding columns, it's a passthrough + let registry: Arc = Arc::new(MemoryRegistry::new()); + let mut scannable = + scannable_with_embeddings(Box::new(batch.clone()), &table_def, Some(®istry)) + .unwrap(); + + // Check that data passes through unchanged + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], batch); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_with_embeddings() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = StringArray::from(vec!["hello", "world"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + // Create a table definition with an embedding column + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + let embedding_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])); + let table_def = TableDefinition::new( + embedding_schema, + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def.clone()), + }, + ], + ); + + // Register the mock embedding function + let registry: Arc = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let mut scannable = + scannable_with_embeddings(Box::new(batch), &table_def, Some(®istry)).unwrap(); + + // Read and verify the data has embeddings + let stream = scannable.scan_as_stream(); + let results: Vec = stream.try_collect().await.unwrap(); + assert_eq!(results.len(), 1); + + let result_batch = &results[0]; + assert_eq!(result_batch.num_columns(), 2); + assert_eq!(result_batch.schema().field(1).name(), "text_embedding"); + } + + #[tokio::test] + async fn test_maybe_embedded_scannable_missing_function() { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)])); + let text_array = StringArray::from(vec!["hello"]); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(text_array) as ArrayRef]) + .unwrap(); + + // Create a table definition with an embedding column + let embedding_def = + EmbeddingDefinition::new("text", "nonexistent", Some("text_embedding")); + let embedding_schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + 
DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])); + let table_def = TableDefinition::new( + embedding_schema, + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + + // Registry has no embedding functions registered + let registry: Arc = Arc::new(MemoryRegistry::new()); + + let result = scannable_with_embeddings(Box::new(batch), &table_def, Some(®istry)); + + // Should fail because the embedding function is not found + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!( + matches!(err, Error::EmbeddingFunctionNotFound { .. }), + "Expected EmbeddingFunctionNotFound" + ); + } + } +} diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs index 36947727c..a4a221193 100644 --- a/rust/lancedb/src/database.rs +++ b/rust/lancedb/src/database.rs @@ -18,12 +18,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use arrow_array::RecordBatchReader; -use async_trait::async_trait; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use futures::stream; use lance::dataset::ReadParams; -use lance_datafusion::utils::StreamingWriteSource; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, DescribeNamespaceResponse, DropNamespaceRequest, DropNamespaceResponse, ListNamespacesRequest, @@ -31,9 +26,9 @@ use lance_namespace::models::{ }; use lance_namespace::LanceNamespace; -use crate::arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt}; +use crate::data::scannable::Scannable; use crate::error::Result; -use crate::table::{BaseTable, TableDefinition, WriteOptions}; +use crate::table::{BaseTable, WriteOptions}; pub mod listing; pub mod namespace; @@ -115,51 +110,14 @@ impl Default for CreateTableMode { } } -/// The data to start a table or a schema to create an empty table -pub enum CreateTableData { - /// Creates a table using an iterator of data, the schema will be obtained from the data - Data(Box), - /// Creates a table using a stream of data, the schema will be obtained from the data - StreamingData(SendableRecordBatchStream), - /// Creates an empty table, the definition / schema must be provided separately - Empty(TableDefinition), -} - -impl CreateTableData { - pub fn schema(&self) -> Arc { - match self { - Self::Data(reader) => reader.schema(), - Self::StreamingData(stream) => stream.schema(), - Self::Empty(definition) => definition.schema.clone(), - } - } -} - -#[async_trait] -impl StreamingWriteSource for CreateTableData { - fn arrow_schema(&self) -> Arc { - self.schema() - } - fn into_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream { - match self { - Self::Data(reader) => reader.into_stream(), - Self::StreamingData(stream) => stream.into_df_stream(), - Self::Empty(table_definition) => { - let schema = table_definition.schema.clone(); - Box::pin(RecordBatchStreamAdapter::new(schema, stream::empty())) - } - } - } -} - /// A request to create a table pub struct CreateTableRequest { /// The name of the new table pub name: String, /// The namespace to create the table in. Empty list represents root namespace. pub namespace: Vec, - /// Initial data to write to the table, can be None to create an empty table - pub data: CreateTableData, + /// Initial data to write to the table, can be empty. 
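+    ///
+    /// Any `Scannable` source works here; for example (a sketch, where `batch`
+    /// and `schema` are placeholders for your own data): `Box::new(batch)` for
+    /// a single `RecordBatch`, or `Box::new(RecordBatch::new_empty(schema))`
+    /// to create an empty table, as the tests below do.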
+ pub data: Box, /// The mode to use when creating the table pub mode: CreateTableMode, /// Options to use when writing data (only used if `data` is not None) @@ -173,7 +131,7 @@ pub struct CreateTableRequest { } impl CreateTableRequest { - pub fn new(name: String, data: CreateTableData) -> Self { + pub fn new(name: String, data: Box) -> Self { Self { name, namespace: vec![], diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs index b2b08b607..aea5ce24c 100644 --- a/rust/lancedb/src/database/listing.rs +++ b/rust/lancedb/src/database/listing.rs @@ -922,7 +922,7 @@ impl Database for ListingDatabase { .with_read_params(read_params.clone()) .load() .await - .map_err(|e| Error::Lance { source: e })?; + .map_err(|e| -> Error { e.into() })?; let version_ref = match (request.source_version, request.source_tag) { (Some(v), None) => Ok(Ref::Version(None, Some(v))), @@ -937,7 +937,7 @@ impl Database for ListingDatabase { source_dataset .shallow_clone(&target_uri, version_ref, Some(storage_params)) .await - .map_err(|e| Error::Lance { source: e })?; + .map_err(|e| -> Error { e.into() })?; let cloned_table = NativeTable::open_with_params( &target_uri, @@ -1098,8 +1098,10 @@ impl Database for ListingDatabase { mod tests { use super::*; use crate::connection::ConnectRequest; - use crate::database::{CreateTableData, CreateTableMode, CreateTableRequest, WriteOptions}; - use crate::table::{Table, TableDefinition}; + use crate::data::scannable::Scannable; + use crate::database::{CreateTableMode, CreateTableRequest}; + use crate::table::WriteOptions; + use crate::Table; use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use std::path::PathBuf; @@ -1139,7 +1141,7 @@ mod tests { .create_table(CreateTableRequest { name: "source_table".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema.clone())), + data: Box::new(RecordBatch::new_empty(schema.clone())) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1196,16 +1198,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "source_with_data".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1264,7 +1261,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1300,7 +1297,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1340,7 +1337,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1380,7 +1377,7 @@ mod tests { 
db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1435,7 +1432,7 @@ mod tests { db.create_table(CreateTableRequest { name: "source".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1484,16 +1481,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "versioned_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1517,14 +1509,7 @@ mod tests { let db = Arc::new(db); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch2)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch2).execute().await.unwrap(); // Verify source table now has 4 rows assert_eq!(source_table.count_rows(None).await.unwrap(), 4); @@ -1570,16 +1555,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "tagged_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1607,14 +1587,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch2)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch2).execute().await.unwrap(); // Source table should have 4 rows assert_eq!(source_table.count_rows(None).await.unwrap(), 4); @@ -1657,16 +1630,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "independent_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1706,14 +1674,7 @@ mod tests { let db = Arc::new(db); let cloned_table_obj = Table::new(cloned_table.clone(), db.clone()); - cloned_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch_clone)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + cloned_table_obj.add(batch_clone).execute().await.unwrap(); // Add different data to the source table let batch_source = RecordBatch::try_new( @@ -1726,14 +1687,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch_source)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch_source).execute().await.unwrap(); // Verify they have evolved independently 
assert_eq!(source_table.count_rows(None).await.unwrap(), 4); // 2 + 2 @@ -1751,16 +1705,11 @@ mod tests { RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))]) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch1)], - schema.clone(), - )); - let source_table = db .create_table(CreateTableRequest { name: "latest_version_source".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch1), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1779,14 +1728,7 @@ mod tests { .unwrap(); let source_table_obj = Table::new(source_table.clone(), db.clone()); - source_table_obj - .add(Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - ))) - .execute() - .await - .unwrap(); + source_table_obj.add(batch).execute().await.unwrap(); } // Source should have 8 rows total (2 + 2 + 2 + 2) @@ -1849,16 +1791,11 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let table = db .create_table(CreateTableRequest { name: "test_stable".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -1887,11 +1824,6 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let mut storage_options = HashMap::new(); storage_options.insert( OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS.to_string(), @@ -1914,7 +1846,7 @@ mod tests { .create_table(CreateTableRequest { name: "test_stable_table_level".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options, location: None, @@ -1963,11 +1895,6 @@ mod tests { ) .unwrap(); - let reader = Box::new(arrow_array::RecordBatchIterator::new( - vec![Ok(batch)], - schema.clone(), - )); - let mut storage_options = HashMap::new(); storage_options.insert( OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS.to_string(), @@ -1990,7 +1917,7 @@ mod tests { .create_table(CreateTableRequest { name: "test_override".to_string(), namespace: vec![], - data: CreateTableData::Data(reader), + data: Box::new(batch), mode: CreateTableMode::Create, write_options, location: None, @@ -2108,7 +2035,7 @@ mod tests { db.create_table(CreateTableRequest { name: "table1".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema.clone())), + data: Box::new(RecordBatch::new_empty(schema.clone())) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, @@ -2120,7 +2047,7 @@ mod tests { db.create_table(CreateTableRequest { name: "table2".to_string(), namespace: vec![], - data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + data: Box::new(RecordBatch::new_empty(schema)) as Box, mode: CreateTableMode::Create, write_options: Default::default(), location: None, diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index b8ad19cd0..91a55809e 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -354,15 +354,13 @@ mod tests { use super::*; use crate::connect_namespace; use crate::query::ExecutableQuery; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; 
use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use tempfile::tempdir; /// Helper function to create test data - fn create_test_data() -> RecordBatchIterator< - std::vec::IntoIter>, - > { + fn create_test_data() -> RecordBatch { let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), @@ -371,12 +369,7 @@ mod tests { let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "David", "Eve"]); - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(id_array), Arc::new(name_array)], - ) - .unwrap(); - RecordBatchIterator::new(vec![std::result::Result::Ok(batch)].into_iter(), schema) + RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(name_array)]).unwrap() } #[tokio::test] @@ -618,13 +611,7 @@ mod tests { // Test: Overwrite the table let table2 = conn - .create_table( - "overwrite_test", - RecordBatchIterator::new( - vec![std::result::Result::Ok(test_data2)].into_iter(), - schema, - ), - ) + .create_table("overwrite_test", test_data2) .namespace(vec!["test_ns".into()]) .mode(CreateTableMode::Overwrite) .execute() diff --git a/rust/lancedb/src/dataloader/permutation/builder.rs b/rust/lancedb/src/dataloader/permutation/builder.rs index 9fb507443..c0c418e55 100644 --- a/rust/lancedb/src/dataloader/permutation/builder.rs +++ b/rust/lancedb/src/dataloader/permutation/builder.rs @@ -13,7 +13,7 @@ use lance_datafusion::exec::SessionContextExt; use crate::{ arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream}, connect, - database::{CreateTableData, CreateTableRequest, Database}, + database::{CreateTableRequest, Database}, dataloader::permutation::{ shuffle::{Shuffler, ShufflerConfig}, split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN}, @@ -313,10 +313,8 @@ impl PermutationBuilder { } }; - let create_table_request = CreateTableRequest::new( - name.to_string(), - CreateTableData::StreamingData(streaming_data), - ); + let create_table_request = + CreateTableRequest::new(name.to_string(), Box::new(streaming_data)); let table = database.create_table(create_table_request).await?; @@ -347,7 +345,7 @@ mod tests { .col("col_b", lance_datagen::array::step::()) .into_ldb_stream(RowCount::from(100), BatchCount::from(10)); let data_table = db - .create_table_streaming("base_tbl", initial_data) + .create_table("base_tbl", initial_data) .execute() .await .unwrap(); @@ -387,7 +385,7 @@ mod tests { .col("some_value", lance_datagen::array::step::()) .into_ldb_stream(RowCount::from(100), BatchCount::from(10)); let data_table = db - .create_table_streaming("mytbl", initial_data) + .create_table("mytbl", initial_data) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs index 56b87c1cb..b6272a123 100644 --- a/rust/lancedb/src/embeddings.rs +++ b/rust/lancedb/src/embeddings.rs @@ -18,7 +18,7 @@ use std::{ }; use arrow_array::{Array, RecordBatch, RecordBatchReader}; -use arrow_schema::{DataType, Field, SchemaBuilder}; +use arrow_schema::{DataType, Field, SchemaBuilder, SchemaRef}; // use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -190,6 +190,112 @@ impl WithEmbeddings { } } +/// Compute embedding arrays for a batch. +/// +/// When multiple embedding functions are defined, they are computed in parallel using +/// scoped threads. For a single embedding function, computation is done inline. 
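+///
+/// Scoped threads are used so each worker can borrow the `(definition, function)`
+/// pairs without requiring `'static` bounds. The shape of that pattern, as a
+/// simplified sketch (`inputs` and `work` are placeholders):
+///
+/// ```ignore
+/// let results: Vec<_> = std::thread::scope(|s| {
+///     let handles: Vec<_> = inputs.iter().map(|x| s.spawn(|| work(x))).collect();
+///     handles.into_iter().map(|h| h.join().unwrap()).collect()
+/// });
+/// ```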
+fn compute_embedding_arrays(
+    batch: &RecordBatch,
+    embeddings: &[(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)],
+) -> Result<Vec<Arc<dyn Array>>> {
+    if embeddings.len() == 1 {
+        let (fld, func) = &embeddings[0];
+        let src_column =
+            batch
+                .column_by_name(&fld.source_column)
+                .ok_or_else(|| Error::InvalidInput {
+                    message: format!("Source column '{}' not found", fld.source_column),
+                })?;
+        return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
+    }
+
+    // Parallel path: multiple embeddings
+    std::thread::scope(|s| {
+        let handles: Vec<_> = embeddings
+            .iter()
+            .map(|(fld, func)| {
+                let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
+                    Error::InvalidInput {
+                        message: format!("Source column '{}' not found", fld.source_column),
+                    }
+                })?;
+
+                let handle = s.spawn(move || func.compute_source_embeddings(src_column.clone()));
+
+                Ok(handle)
+            })
+            .collect::<Result<_>>()?;
+
+        handles
+            .into_iter()
+            .map(|h| {
+                h.join().map_err(|e| Error::Runtime {
+                    message: format!("Thread panicked during embedding computation: {:?}", e),
+                })?
+            })
+            .collect()
+    })
+}
+
+/// Compute the output schema when embeddings are applied to a base schema.
+///
+/// This returns the schema with embedding columns appended.
+pub fn compute_output_schema(
+    base_schema: &SchemaRef,
+    embeddings: &[(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)],
+) -> Result<SchemaRef> {
+    let mut sb: SchemaBuilder = base_schema.as_ref().into();
+
+    for (ed, func) in embeddings {
+        let src_field = base_schema
+            .field_with_name(&ed.source_column)
+            .map_err(|_| Error::InvalidInput {
+                message: format!("Source column '{}' not found in schema", ed.source_column),
+            })?;
+
+        let field_name = ed
+            .dest_column
+            .clone()
+            .unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
+
+        sb.push(Field::new(
+            field_name,
+            func.dest_type()?.into_owned(),
+            src_field.is_nullable(),
+        ));
+    }
+
+    Ok(Arc::new(sb.finish()))
+}
+
+/// Compute embeddings for a batch and append as new columns.
+///
+/// This function computes embeddings using the provided embedding functions and
+/// appends them as new columns to the batch.
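+///
+/// A sketch of the expected shape (`batch`, `def`, and `func` are placeholders
+/// for a populated batch, an `EmbeddingDefinition`, and its function):
+///
+/// ```ignore
+/// let out = compute_embeddings_for_batch(batch, &[(def, func)])?;
+/// // One new column per embedding definition, appended after the source columns.
+/// assert_eq!(out.num_columns(), 2);
+/// ```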
+pub fn compute_embeddings_for_batch( + batch: RecordBatch, + embeddings: &[(EmbeddingDefinition, Arc)], +) -> Result { + let embedding_arrays = compute_embedding_arrays(&batch, embeddings)?; + + let mut result = batch; + for ((fld, _), embedding) in embeddings.iter().zip(embedding_arrays.iter()) { + let dst_field_name = fld + .dest_column + .clone() + .unwrap_or_else(|| format!("{}_embedding", &fld.source_column)); + + let dst_field = Field::new( + dst_field_name, + embedding.data_type().clone(), + embedding.nulls().is_some(), + ); + + result = result.try_with_column(dst_field, embedding.clone())?; + } + Ok(result) +} + impl WithEmbeddings { fn dest_fields(&self) -> Result> { let schema = self.inner.schema(); @@ -240,48 +346,6 @@ impl WithEmbeddings { column_definitions, }) } - - fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result>> { - if self.embeddings.len() == 1 { - let (fld, func) = &self.embeddings[0]; - let src_column = - batch - .column_by_name(&fld.source_column) - .ok_or_else(|| Error::InvalidInput { - message: format!("Source column '{}' not found", fld.source_column), - })?; - return Ok(vec![func.compute_source_embeddings(src_column.clone())?]); - } - - // Parallel path: multiple embeddings - std::thread::scope(|s| { - let handles: Vec<_> = self - .embeddings - .iter() - .map(|(fld, func)| { - let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| { - Error::InvalidInput { - message: format!("Source column '{}' not found", fld.source_column), - } - })?; - - let handle = - s.spawn(move || func.compute_source_embeddings(src_column.clone())); - - Ok(handle) - }) - .collect::>()?; - - handles - .into_iter() - .map(|h| { - h.join().map_err(|e| Error::Runtime { - message: format!("Thread panicked during embedding computation: {:?}", e), - })? 
- }) - .collect() - }) - } } impl Iterator for MaybeEmbedded { @@ -309,37 +373,13 @@ impl Iterator for WithEmbeddings { fn next(&mut self) -> Option { let batch = self.inner.next()?; match batch { - Ok(batch) => { - let embeddings = match self.compute_embeddings_parallel(&batch) { - Ok(emb) => emb, - Err(e) => { - return Some(Err(arrow_schema::ArrowError::ComputeError(format!( - "Error computing embedding: {}", - e - )))) - } - }; - - let mut batch = batch; - for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) { - let dst_field_name = fld - .dest_column - .clone() - .unwrap_or_else(|| format!("{}_embedding", &fld.source_column)); - - let dst_field = Field::new( - dst_field_name, - embedding.data_type().clone(), - embedding.nulls().is_some(), - ); - - match batch.try_with_column(dst_field.clone(), embedding.clone()) { - Ok(b) => batch = b, - Err(e) => return Some(Err(e)), - }; - } - Some(Ok(batch)) - } + Ok(batch) => match compute_embeddings_for_batch(batch, &self.embeddings) { + Ok(batch_with_embeddings) => Some(Ok(batch_with_embeddings)), + Err(e) => Some(Err(arrow_schema::ArrowError::ComputeError(format!( + "Error computing embedding: {}", + e + )))), + }, Err(e) => Some(Err(e)), } } diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 4312b3860..55e2350ac 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -6,7 +6,7 @@ use std::sync::PoisonError; use arrow_schema::ArrowError; use snafu::Snafu; -type BoxError = Box; +pub(crate) type BoxError = Box; #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] @@ -80,6 +80,9 @@ pub enum Error { Arrow { source: ArrowError }, #[snafu(display("LanceDBError: not supported: {message}"))] NotSupported { message: String }, + /// External error pass through from user code. + #[snafu(transparent)] + External { source: BoxError }, #[snafu(whatever, display("{message}"))] Other { message: String, @@ -92,15 +95,26 @@ pub type Result = std::result::Result; impl From for Error { fn from(source: ArrowError) -> Self { - Self::Arrow { source } + match source { + ArrowError::ExternalError(source) => match source.downcast::() { + Ok(e) => *e, + Err(source) => Self::External { source }, + }, + _ => Self::Arrow { source }, + } } } impl From for Error { fn from(source: lance::Error) -> Self { - // TODO: Once Lance is changed to preserve ObjectStore, DataFusion, and Arrow errors, we can - // pass those variants through here as well. - Self::Lance { source } + // Try to unwrap external errors that were wrapped by lance + match source { + lance::Error::Wrapped { error, .. 
} => match error.downcast::() { + Ok(e) => *e, + Err(source) => Self::External { source }, + }, + _ => Self::Lance { source }, + } } } diff --git a/rust/lancedb/src/io/object_store.rs b/rust/lancedb/src/io/object_store.rs index 4988ae05c..d935c099a 100644 --- a/rust/lancedb/src/io/object_store.rs +++ b/rust/lancedb/src/io/object_store.rs @@ -218,8 +218,9 @@ mod test { datagen = datagen.col(Box::::default()); datagen = datagen.col(Box::new(RandomVector::default().named("vector".into()))); + let data: Box = Box::new(datagen.batch(100)); let res = db - .create_table("test", Box::new(datagen.batch(100))) + .create_table("test", data) .write_options(WriteOptions { lance_write_params: Some(param), }) diff --git a/rust/lancedb/src/ipc.rs b/rust/lancedb/src/ipc.rs index c814b41bf..e71f2e521 100644 --- a/rust/lancedb/src/ipc.rs +++ b/rust/lancedb/src/ipc.rs @@ -12,10 +12,10 @@ use arrow_schema::Schema; use crate::{Error, Result}; /// Convert a Arrow IPC file to a batch reader -pub fn ipc_file_to_batches(buf: Vec) -> Result { +pub fn ipc_file_to_batches(buf: Vec) -> Result> { let buf_reader = Cursor::new(buf); let reader = FileReader::try_new(buf_reader, None)?; - Ok(reader) + Ok(Box::new(reader)) } /// Convert record batches to Arrow IPC file diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 64699cb15..944613253 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -39,7 +39,6 @@ //! #### Connect to a database. //! //! ```rust -//! # use arrow_schema::{Field, Schema}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! let db = lancedb::connect("data/sample-lancedb").execute().await.unwrap(); //! # }); @@ -74,7 +73,10 @@ //! //! #### Create a table //! -//! To create a Table, you need to provide a [`arrow_schema::Schema`] and a [`arrow_array::RecordBatch`] stream. +//! To create a Table, you need to provide an [`arrow_array::RecordBatch`]. The +//! schema of the `RecordBatch` determines the schema of the table. +//! +//! Vector columns should be represented as `FixedSizeList` data type. //! //! ```rust //! # use std::sync::Arc; @@ -85,34 +87,29 @@ //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); //! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); +//! let ndims = 128; //! let schema = Arc::new(Schema::new(vec![ //! Field::new("id", DataType::Int32, false), //! Field::new( //! "vector", -//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128), +//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), ndims), //! true, //! ), //! ])); -//! // Create a RecordBatch stream. -//! let batches = RecordBatchIterator::new( -//! vec![RecordBatch::try_new( +//! let data = RecordBatch::try_new( //! schema.clone(), //! vec![ //! Arc::new(Int32Array::from_iter_values(0..256)), //! Arc::new( //! FixedSizeListArray::from_iter_primitive::( -//! (0..256).map(|_| Some(vec![Some(1.0); 128])), -//! 128, +//! (0..256).map(|_| Some(vec![Some(1.0); ndims as usize])), +//! ndims, //! ), //! ), //! ], //! ) -//! .unwrap()] -//! .into_iter() -//! .map(Ok), -//! schema.clone(), -//! ); -//! db.create_table("my_table", Box::new(batches)) +//! .unwrap(); +//! db.create_table("my_table", data) //! .execute() //! .await //! .unwrap(); @@ -151,42 +148,18 @@ //! #### Open table and search //! //! ```rust -//! # use std::sync::Arc; //! # use futures::TryStreamExt; -//! # use arrow_schema::{DataType, Schema, Field}; -//! 
# use arrow_array::{RecordBatch, RecordBatchIterator}; -//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; //! # use lancedb::query::{ExecutableQuery, QueryBase}; -//! # tokio::runtime::Runtime::new().unwrap().block_on(async { -//! # let tmpdir = tempfile::tempdir().unwrap(); -//! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); -//! # let schema = Arc::new(Schema::new(vec![ -//! # Field::new("id", DataType::Int32, false), -//! # Field::new("vector", DataType::FixedSizeList( -//! # Arc::new(Field::new("item", DataType::Float32, true)), 128), true), -//! # ])); -//! # let batches = RecordBatchIterator::new(vec![ -//! # RecordBatch::try_new(schema.clone(), -//! # vec![ -//! # Arc::new(Int32Array::from_iter_values(0..10)), -//! # Arc::new(FixedSizeListArray::from_iter_primitive::( -//! # (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)), -//! # ]).unwrap() -//! # ].into_iter().map(Ok), -//! # schema.clone()); -//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap(); -//! # let table = db.open_table("my_table").execute().await.unwrap(); +//! # async fn example(table: &lancedb::Table) -> lancedb::Result<()> { //! let results = table //! .query() -//! .nearest_to(&[1.0; 128]) -//! .unwrap() +//! .nearest_to(&[1.0; 128])? //! .execute() -//! .await -//! .unwrap() +//! .await? //! .try_collect::>() -//! .await -//! .unwrap(); -//! # }); +//! .await?; +//! # Ok(()) +//! # } //! ``` pub mod arrow; diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index c3e2f8faa..b1691c8ac 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -1381,7 +1381,7 @@ mod tests { use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array, - RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + RecordBatch, StringArray, }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::{StreamExt, TryStreamExt}; @@ -1402,7 +1402,7 @@ mod tests { let batches = make_test_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1463,7 +1463,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1525,7 +1525,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1578,7 +1578,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); let table = conn - .create_table("my_table", Box::new(batches)) + .create_table("my_table", batches) .execute() .await .unwrap(); @@ -1599,13 +1599,13 @@ mod tests { assert!(result.is_err()); } - fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { + fn make_non_empty_batches() -> Box { let vec = Box::new(RandomVector::new().named("vector".to_string())); let id = Box::new(IncrementingInt32::new().named("id".to_string())); - BatchGenerator::new().col(vec).col(id).batch(512) + Box::new(BatchGenerator::new().col(vec).col(id).batch(512)) } - fn 
make_test_batches() -> impl RecordBatchReader + Send + 'static { + fn make_test_batches() -> RecordBatch { let dim: usize = 128; let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("key", DataType::Int32, false), @@ -1619,12 +1619,7 @@ mod tests { ), ArrowField::new("uri", DataType::Utf8, true), ])); - RecordBatchIterator::new( - vec![RecordBatch::new_empty(schema.clone())] - .into_iter() - .map(Ok), - schema, - ) + RecordBatch::new_empty(schema) } async fn make_test_table(tmp_dir: &tempfile::TempDir) -> Table { @@ -1633,7 +1628,7 @@ mod tests { let batches = make_non_empty_batches(); let conn = connect(uri).execute().await.unwrap(); - conn.create_table("my_table", Box::new(batches)) + conn.create_table("my_table", batches) .execute() .await .unwrap() @@ -1862,10 +1857,8 @@ mod tests { let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap(); - let record_batch_iter = - RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone()); let table = conn - .create_table("my_table", record_batch_iter) + .create_table("my_table", record_batch) .execute() .await .unwrap(); @@ -1949,10 +1942,8 @@ mod tests { ], ) .unwrap(); - let record_batch_iter = - RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone()); let table = conn - .create_table("my_table", record_batch_iter) + .create_table("my_table", record_batch) .mode(CreateTableMode::Overwrite) .execute() .await @@ -2062,8 +2053,6 @@ mod tests { async fn test_pagination_with_fts() { let db = connect("memory://test").execute().await.unwrap(); let data = fts_test_data(400); - let schema = data.schema(); - let data = RecordBatchIterator::new(vec![Ok(data)], schema); let table = db.create_table("test_table", data).execute().await.unwrap(); table diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 46e51569a..73d2f14da 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -491,7 +491,7 @@ impl RestfulLanceDbClient { } /// Apply dynamic headers from the header provider if configured - async fn apply_dynamic_headers(&self, mut request: Request) -> Result { + pub(crate) async fn apply_dynamic_headers(&self, mut request: Request) -> Result { if let Some(ref provider) = self.header_provider { let headers = provider.get_headers().await?; let request_headers = request.headers_mut(); @@ -555,7 +555,9 @@ impl RestfulLanceDbClient { message: "Attempted to retry a request that cannot be cloned".to_string(), })?; let (_, r) = tmp_req.build_split(); - let mut r = r.unwrap(); + let mut r = r.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; let request_id = self.extract_request_id(&mut r); let mut retry_counter = RetryCounter::new(retry_config, request_id.clone()); @@ -571,7 +573,9 @@ impl RestfulLanceDbClient { } let (c, request) = req_builder.build_split(); - let mut request = request.unwrap(); + let mut request = request.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; self.set_request_id(&mut request, &request_id.clone()); // Apply dynamic headers before each retry attempt @@ -621,7 +625,7 @@ impl RestfulLanceDbClient { } } - fn log_request(&self, request: &Request, request_id: &String) { + pub(crate) fn log_request(&self, request: &Request, request_id: &String) { if log::log_enabled!(log::Level::Debug) { let content_type = request .headers() diff --git a/rust/lancedb/src/remote/db.rs 
b/rust/lancedb/src/remote/db.rs index 66736a872..38541f079 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -4,13 +4,11 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::RecordBatchIterator; use async_trait::async_trait; use http::StatusCode; use lance_io::object_store::StorageOptions; use moka::future::Cache; use reqwest::header::CONTENT_TYPE; -use tokio::task::spawn_blocking; use lance_namespace::models::{ CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest, @@ -19,16 +17,17 @@ use lance_namespace::models::{ }; use crate::database::{ - CloneTableRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database, - DatabaseOptions, OpenTableRequest, ReadConsistency, TableNamesRequest, + CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, + OpenTableRequest, ReadConsistency, TableNamesRequest, }; use crate::error::Result; +use crate::remote::util::stream_as_body; use crate::table::BaseTable; use crate::Error; use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; -use super::util::{batches_to_ipc_bytes, parse_server_version}; +use super::util::parse_server_version; use super::ARROW_STREAM_CONTENT_TYPE; // Request structure for the remote clone table API @@ -436,26 +435,8 @@ impl Database for RemoteDatabase { Ok(response) } - async fn create_table(&self, request: CreateTableRequest) -> Result> { - let data = match request.data { - CreateTableData::Data(data) => data, - CreateTableData::StreamingData(_) => { - return Err(Error::NotSupported { - message: "Creating a remote table from a streaming source".to_string(), - }) - } - CreateTableData::Empty(table_definition) => { - let schema = table_definition.schema.clone(); - Box::new(RecordBatchIterator::new(vec![], schema)) - } - }; - - // TODO: https://github.com/lancedb/lancedb/issues/1026 - // We should accept data from an async source. In the meantime, spawn this as blocking - // to make sure we don't block the tokio runtime if the source is slow. 
- let data_buffer = spawn_blocking(move || batches_to_ipc_bytes(data)) - .await - .unwrap()?; + async fn create_table(&self, mut request: CreateTableRequest) -> Result> { + let body = stream_as_body(request.data.scan_as_stream())?; let identifier = build_table_identifier(&request.name, &request.namespace, &self.client.id_delimiter); @@ -463,7 +444,7 @@ impl Database for RemoteDatabase { .client .post(&format!("/v1/table/{}/create/", identifier)) .query(&[("mode", Into::<&str>::into(&request.mode))]) - .body(data_buffer) + .body(body) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); let (request_id, rsp) = self.client.send(req).await?; @@ -813,7 +794,7 @@ mod tests { use std::collections::HashMap; use std::sync::{Arc, OnceLock}; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use crate::connection::ConnectBuilder; @@ -993,8 +974,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let table = conn.create_table("table1", reader).execute().await.unwrap(); + let table = conn.create_table("table1", data).execute().await.unwrap(); assert_eq!(table.name(), "table1"); } @@ -1011,8 +991,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let result = conn.create_table("table1", reader).execute().await; + let result = conn.create_table("table1", data).execute().await; assert!(result.is_err()); assert!( matches!(result, Err(crate::Error::TableAlreadyExists { name }) if name == "table1") @@ -1045,8 +1024,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); - let mut builder = conn.create_table("table1", reader); + let mut builder = conn.create_table("table1", data.clone()); if let Some(mode) = mode { builder = builder.mode(mode); } @@ -1071,9 +1049,8 @@ mod tests { .unwrap(); let called: Arc> = Arc::new(OnceLock::new()); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); let called_in_cb = called.clone(); - conn.create_table("table1", reader) + conn.create_table("table1", data) .mode(CreateTableMode::ExistOk(Box::new(move |b| { called_in_cb.clone().set(true).unwrap(); b @@ -1262,9 +1239,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); let table = conn - .create_table("table1", reader) + .create_table("table1", data) .namespace(vec!["ns1".to_string()]) .execute() .await @@ -1730,10 +1706,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); - let table = conn - .create_table("test_table", reader) + .create_table("test_table", data) .namespace(namespace.clone()) .execute() .await; @@ -1806,9 +1780,7 @@ mod tests { let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![i]))]) .unwrap(); - let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); - - conn.create_table(format!("table{}", i), reader) + conn.create_table(format!("table{}", i), data) .namespace(namespace.clone()) .execute() .await diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index cf0989b0b..8ed4b6236 100644 --- 
a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -3,9 +3,11 @@ pub mod insert; +use crate::data::scannable::Scannable; use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; +use crate::remote::util::stream_as_ipc; use crate::table::AddColumnsResult; use crate::table::AddResult; use crate::table::AlterColumnsResult; @@ -45,10 +47,10 @@ use tokio::sync::RwLock; use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::db::ServerVersion; +use super::util::stream_as_body; use super::ARROW_STREAM_CONTENT_TYPE; use crate::index::waiter::wait_for_index; use crate::{ - connection::NoData, error::Result, index::{IndexBuilder, IndexConfig}, query::QueryExecutionOptions, @@ -264,7 +266,10 @@ impl RemoteTable { fn reader_as_body(data: Box) -> Result { // TODO: Once Phalanx supports compression, we should use it here. - let mut writer = arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &data.schema())?; + let mut writer = arrow_ipc::writer::StreamWriter::try_new( + Vec::new(), + &RecordBatchReader::schema(&*data), + )?; // Mutex is just here to make it sync. We shouldn't have any contention. let mut data = Mutex::new(data); @@ -340,6 +345,110 @@ impl RemoteTable { Ok(res) } + /// Send a request with data from a Scannable source. + /// + /// For rescannable sources, this will retry on retryable errors by re-reading + /// the data. For non-rescannable sources (streams), only a single attempt is made. + async fn send_scannable( + &self, + req_builder: RequestBuilder, + data: &mut dyn Scannable, + ) -> Result<(String, Response)> { + use crate::remote::retry::RetryCounter; + + // Right now, Python and Typescript don't pass down re-scannable data yet. + // So to preserve existing retry behavior, we have to collect data in + // memory for now. Once they expose rescannable data sources, we can remove this. + if !data.rescannable() && self.client.retry_config.retries > 0 { + let mut body = Vec::new(); + stream_as_ipc(data.scan_as_stream())? 
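+                // Drain the IPC byte stream into a single in-memory buffer so the
+                // request body is replayable by `send_with_retry` below.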
+ .try_for_each(|b| { + body.extend_from_slice(&b); + futures::future::ok(()) + }) + .await?; + let req_builder = req_builder.body(body); + return self.client.send_with_retry(req_builder, None, true).await; + } + + let rescannable = data.rescannable(); + let max_retries = if rescannable { + self.client.retry_config.retries + } else { + 0 + }; + + // Clone the request builder to extract the request id + let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime { + message: "Attempted to retry a request that cannot be cloned".to_string(), + })?; + let (_, r) = tmp_req.build_split(); + let mut r = r.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; + let request_id = self.client.extract_request_id(&mut r); + let mut retry_counter = RetryCounter::new(&self.client.retry_config, request_id.clone()); + + loop { + // Re-read data on each attempt + let stream = data.scan_as_stream(); + let body = stream_as_body(stream)?; + + let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime { + message: "Attempted to retry a request that cannot be cloned".to_string(), + })?; + req_builder = req_builder.body(body); + + let (c, request) = req_builder.build_split(); + let mut request = request.map_err(|e| Error::Runtime { + message: format!("Failed to build request: {}", e), + })?; + self.client.set_request_id(&mut request, &request_id); + + // Apply dynamic headers + request = self.client.apply_dynamic_headers(request).await?; + + self.client.log_request(&request, &request_id); + + let response = match self.client.sender.send(&c, request).await { + Ok(r) => r, + Err(err) => { + if err.is_connect() { + retry_counter.increment_connect_failures(err)?; + } else if err.is_body() || err.is_decode() { + retry_counter.increment_read_failures(err)?; + } else { + return Err(crate::Error::Http { + source: err.into(), + request_id, + status_code: None, + }); + } + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + }; + + let status = response.status(); + + // Check for retryable status codes + if self.client.retry_config.statuses.contains(&status) + && retry_counter.request_failures < max_retries + { + let http_err = crate::Error::Http { + source: format!("Retryable status code: {}", status).into(), + request_id: request_id.clone(), + status_code: Some(status), + }; + retry_counter.increment_request_failures(http_err)?; + tokio::time::sleep(retry_counter.next_sleep_time()).await; + continue; + } + + return Ok((request_id, response)); + } + } + pub(super) async fn handle_table_not_found( table_name: &str, response: reqwest::Response, @@ -656,8 +765,9 @@ impl std::fmt::Display for RemoteTable { #[cfg(all(test, feature = "remote"))] mod test_utils { use super::*; - use crate::remote::client::test_utils::client_with_handler; use crate::remote::client::test_utils::MockSender; + use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config}; + use crate::remote::ClientConfig; impl RemoteTable { pub fn new_mock(name: String, handler: F, version: Option) -> Self @@ -676,6 +786,23 @@ mod test_utils { location: RwLock::new(None), } } + + pub fn new_mock_with_config(name: String, handler: F, config: ClientConfig) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + let client = client_with_handler_and_config(handler, config); + Self { + client, + name: name.clone(), + namespace: vec![], + identifier: name, + server_version: ServerVersion::default(), + version: 
RwLock::new(None), + location: RwLock::new(None), + } + } } } @@ -797,11 +924,7 @@ impl BaseTable for RemoteTable { status_code: None, }) } - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result { + async fn add(&self, mut add: AddDataBuilder) -> Result { self.check_mutable().await?; let mut request = self .client @@ -815,7 +938,7 @@ impl BaseTable for RemoteTable { } } - let (request_id, response) = self.send_streaming(request, data, true).await?; + let (request_id, response) = self.send_scannable(request, &mut *add.data).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; if body.trim().is_empty() { @@ -1584,12 +1707,14 @@ impl TryFrom for MergeInsertRequest { #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use std::{collections::HashMap, pin::Pin}; use super::*; use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type}; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; use futures::{future::BoxFuture, StreamExt, TryFutureExt}; @@ -1623,7 +1748,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let example_data = || { + let example_data_for_add = || batch.clone(); + let example_data_for_merge = || -> Box { Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), @@ -1636,11 +1762,11 @@ mod tests { Box::pin(table.schema().map_ok(|_| ())), Box::pin(table.count_rows(None).map_ok(|_| ())), Box::pin(table.update().column("a", "a + 1").execute().map_ok(|_| ())), - Box::pin(table.add(example_data()).execute().map_ok(|_| ())), + Box::pin(table.add(example_data_for_add()).execute().map_ok(|_| ())), Box::pin( table .merge_insert(&["test"]) - .execute(example_data()) + .execute(example_data_for_merge()) .map_ok(|_| ()), ), Box::pin(table.delete("false").map_ok(|_| ())), @@ -1772,8 +1898,15 @@ mod tests { fn write_ipc_stream(data: &RecordBatch) -> Vec { let mut body = Vec::new(); { - let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut body, &data.schema()) - .expect("Failed to create writer"); + let options = arrow_ipc::writer::IpcWriteOptions::default() + .try_with_compression(Some(arrow_ipc::CompressionType::LZ4_FRAME)) + .expect("Failed to create IPC write options"); + let mut writer = arrow_ipc::writer::StreamWriter::try_new_with_options( + &mut body, + &data.schema(), + options, + ) + .expect("Failed to create writer"); writer.write(data).expect("Failed to write data"); writer.finish().expect("Failed to finish"); } @@ -1831,11 +1964,7 @@ mod tests { panic!("Unexpected request path: {}", request.url().path()); } }); - let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await - .unwrap(); + let result = table.add(data.clone()).execute().await.unwrap(); // Check version matches expected value assert_eq!(result.version, expected_version); @@ -1892,7 +2021,7 @@ mod tests { }); let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .add(data.clone()) .mode(AddDataMode::Overwrite) .execute() .await @@ -2032,7 +2161,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], 
batch.schema(), )); @@ -2084,7 +2213,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), )); @@ -3015,7 +3144,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let data = Box::new(RecordBatchIterator::new( + let data: Box = Box::new(RecordBatchIterator::new( [Ok(batch.clone())], batch.schema(), )); @@ -3030,10 +3159,7 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let res = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await; + let res = table.add(data.clone()).execute().await; assert!(matches!(res, Err(Error::NotSupported { .. }))); let res = table @@ -3295,11 +3421,7 @@ mod tests { } }); - let result = table - .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) - .execute() - .await - .unwrap(); + let result = table.add(data.clone()).execute().await.unwrap(); assert_eq!(result.version, 2); @@ -3446,4 +3568,95 @@ mod tests { assert_eq!(uri2, "gs://bucket/table"); assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call } + + #[tokio::test] + async fn test_add_retries_rescannable_data() { + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + // Configure with retries enabled (default is 3) + let config = crate::remote::ClientConfig::default(); + + let table = Table::new_with_handler_and_config( + "my_table", + move |_request| { + let count = call_count_clone.fetch_add(1, Ordering::SeqCst); + if count < 2 { + // First two attempts fail with a retryable error (409) + http::Response::builder().status(409).body("").unwrap() + } else { + // Third attempt succeeds + http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#) + .unwrap() + } + }, + config, + ); + + // RecordBatch is rescannable - should retry and succeed + let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + let result = table.add(batch).execute().await; + + assert!( + result.is_ok(), + "Expected success after retries: {:?}", + result + ); + assert_eq!( + call_count.load(Ordering::SeqCst), + 3, + "Expected 2 failed attempts + 1 success = 3 total" + ); + } + + #[tokio::test] + async fn test_add_no_retry_for_non_rescannable() { + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + // Configure with retries enabled + let config = crate::remote::ClientConfig::default(); + + let table = Table::new_with_handler_and_config( + "my_table", + move |_request| { + call_count_clone.fetch_add(1, Ordering::SeqCst); + // Always fail with retryable error + http::Response::builder().status(409).body("").unwrap() + }, + config, + ); + + // RecordBatchReader is NOT rescannable - should NOT retry + let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + let reader: Box = Box::new( + RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), + ); + + let result = table.add(reader).execute().await; + + // Should fail because we can't retry non-rescannable sources + assert!(result.is_err()); + // Right now, we actually do retry, so we get 3 failures. In the future + // this will change and we need to update the test. + assert!( + matches!( + result.unwrap_err(), + Error::Retry { + request_failures: 3, + .. 
+ } + ), + "Expected Error::Retry after 3 failed attempts" + ); + // TODO: After we implement proper non-rescannable handling, uncomment below + // (This is blocked on getting Python and Node to pass down re-scannable data.) + // assert_eq!( + // call_count.load(Ordering::SeqCst), + // 1, + // "Expected only one attempt for non-rescannable source" + // ); + } } diff --git a/rust/lancedb/src/remote/util.rs b/rust/lancedb/src/remote/util.rs index 4b92ec0c2..51ddb97de 100644 --- a/rust/lancedb/src/remote/util.rs +++ b/rust/lancedb/src/remote/util.rs @@ -1,29 +1,50 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::io::Cursor; - -use arrow_array::RecordBatchReader; +use arrow_ipc::CompressionType; +use futures::{Stream, StreamExt}; use reqwest::Response; -use crate::Result; +use crate::{arrow::SendableRecordBatchStream, Result}; use super::db::ServerVersion; -pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> { +pub fn stream_as_ipc( + data: SendableRecordBatchStream, +) -> Result<impl Stream<Item = Result<bytes::Bytes>>> { + let options = arrow_ipc::writer::IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; const WRITE_BUF_SIZE: usize = 4096; let buf = Vec::with_capacity(WRITE_BUF_SIZE); - let mut buf = Cursor::new(buf); - { - let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut buf, &batches.schema())?; + let writer = + arrow_ipc::writer::StreamWriter::try_new_with_options(buf, &data.schema(), options)?; + let stream = futures::stream::try_unfold( + (data, writer, false), + move |(mut data, mut writer, finished)| async move { + if finished { + return Ok(None); + } + match data.next().await { + Some(Ok(batch)) => { + writer.write(&batch)?; + let buffer = std::mem::take(writer.get_mut()); + Ok(Some((bytes::Bytes::from(buffer), (data, writer, false)))) + } + Some(Err(e)) => Err(e), + None => { + writer.finish()?; + let buffer = std::mem::take(writer.get_mut()); + Ok(Some((bytes::Bytes::from(buffer), (data, writer, true)))) + } + } + }, + ); + Ok(stream) +} - for batch in batches { - let batch = batch?; - writer.write(&batch)?; - } - writer.finish()?; - } - Ok(buf.into_inner()) +pub fn stream_as_body(data: SendableRecordBatchStream) -> Result<reqwest::Body> { + let stream = stream_as_ipc(data)?; + Ok(reqwest::Body::wrap_stream(stream)) } pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result<ServerVersion> { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 776479a8b..b7761f5f9 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -5,7 +5,7 @@ use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; use arrow::datatypes::{Float32Type, UInt8Type}; -use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_expr::Expr; @@ -50,10 +50,9 @@ use std::format; use std::path::Path; use std::sync::Arc; -use crate::arrow::IntoArrow; -use crate::connection::NoData; +use crate::data::scannable::{scannable_with_embeddings, Scannable}; use crate::database::Database; -use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; +use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry}; use crate::error::{Error, Result}; use crate::index::vector::VectorIndex; use crate::index::IndexStatistics; @@ -72,6 +71,7 @@ use crate::utils::{ use self::dataset::DatasetConsistencyWrapper; use
self::merge::MergeInsertBuilder; +mod add_data; pub mod datafusion; pub(crate) mod dataset; pub mod delete; @@ -80,6 +80,8 @@ pub mod optimize; pub mod schema_evolution; pub mod update; +pub use add_data::{AddDataBuilder, AddDataMode, AddResult}; + use crate::index::waiter::wait_for_index; pub use chrono::Duration; pub use delete::DeleteResult; @@ -196,60 +198,6 @@ pub struct WriteOptions { pub lance_write_params: Option, } -#[derive(Debug, Clone, Default)] -pub enum AddDataMode { - /// Rows will be appended to the table (the default) - #[default] - Append, - /// The existing table will be overwritten with the new data - Overwrite, -} - -/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`] -/// operation -pub struct AddDataBuilder { - parent: Arc, - pub(crate) data: T, - pub(crate) mode: AddDataMode, - pub(crate) write_options: WriteOptions, - embedding_registry: Option>, -} - -impl std::fmt::Debug for AddDataBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AddDataBuilder") - .field("parent", &self.parent) - .field("mode", &self.mode) - .field("write_options", &self.write_options) - .finish() - } -} - -impl AddDataBuilder { - pub fn mode(mut self, mode: AddDataMode) -> Self { - self.mode = mode; - self - } - - pub fn write_options(mut self, options: WriteOptions) -> Self { - self.write_options = options; - self - } - - pub async fn execute(self) -> Result { - let parent = self.parent.clone(); - let data = self.data.into_arrow()?; - let without_data = AddDataBuilder:: { - data: NoData {}, - mode: self.mode, - parent: self.parent, - write_options: self.write_options, - embedding_registry: self.embedding_registry, - }; - parent.add(without_data, data).await - } -} - /// Filters that can be used to limit the rows returned by a query pub enum Filter { /// A SQL filter string @@ -283,15 +231,6 @@ pub trait Tags: Send + Sync { async fn update(&mut self, tag: &str, version: u64) -> Result<()>; } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct AddResult { - // The commit version associated with the operation. - // A version of `0` indicates compatibility with legacy servers that do not return - /// a commit version. - #[serde(default)] - pub version: u64, -} - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct MergeResult { // The commit version associated with the operation. @@ -364,11 +303,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { ) -> Result; /// Add new records to the table. - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result; + async fn add(&self, add: AddDataBuilder) -> Result; /// Delete rows from the table. async fn delete(&self, predicate: &str) -> Result; /// Update rows in the table. @@ -513,6 +448,30 @@ mod test_utils { embedding_registry: Arc::new(MemoryRegistry::new()), } } + + pub fn new_with_handler_and_config( + name: impl Into, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + config: crate::remote::ClientConfig, + ) -> Self + where + T: Into, + { + let inner = Arc::new(crate::remote::table::RemoteTable::new_mock_with_config( + name.into(), + handler.clone(), + config.clone(), + )); + let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config( + handler, config, + )); + Self { + inner, + database: Some(database), + // Registry is unused. 
+ embedding_registry: Arc::new(MemoryRegistry::new()), + } + } } } @@ -613,16 +572,14 @@ impl Table { /// /// # Arguments /// - /// * `batches` data to be added to the Table + /// * `data` data to be added to the Table /// * `options` options to control how data is added - pub fn add(&self, batches: T) -> AddDataBuilder { - AddDataBuilder { - parent: self.inner.clone(), - data: batches, - mode: AddDataMode::Append, - write_options: WriteOptions::default(), - embedding_registry: Some(self.embedding_registry.clone()), - } + pub fn add(&self, data: T) -> AddDataBuilder { + AddDataBuilder::new( + self.inner.clone(), + Box::new(data), + Some(self.embedding_registry.clone()), + ) } /// Update existing records in the Table @@ -661,31 +618,26 @@ impl Table { /// .execute() /// .await /// .unwrap(); - /// # let schema = Arc::new(Schema::new(vec![ - /// # Field::new("id", DataType::Int32, false), - /// # Field::new("vector", DataType::FixedSizeList( - /// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true), - /// # ])); - /// let batches = RecordBatchIterator::new( - /// vec![RecordBatch::try_new( - /// schema.clone(), - /// vec![ - /// Arc::new(Int32Array::from_iter_values(0..10)), - /// Arc::new( - /// FixedSizeListArray::from_iter_primitive::( - /// (0..10).map(|_| Some(vec![Some(1.0); 128])), - /// 128, - /// ), - /// ), - /// ], - /// ) - /// .unwrap()] - /// .into_iter() - /// .map(Ok), + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("vector", DataType::FixedSizeList( + /// Arc::new(Field::new("item", DataType::Float32, true)), 128), true), + /// ])); + /// let data = RecordBatch::try_new( /// schema.clone(), - /// ); + /// vec![ + /// Arc::new(Int32Array::from_iter_values(0..10)), + /// Arc::new( + /// FixedSizeListArray::from_iter_primitive::( + /// (0..10).map(|_| Some(vec![Some(1.0); 128])), + /// 128, + /// ), + /// ), + /// ], + /// ) + /// .unwrap(); /// let tbl = db - /// .create_table("delete_test", Box::new(batches)) + /// .create_table("delete_test", data) /// .execute() /// .await /// .unwrap(); @@ -1445,7 +1397,7 @@ impl NativeTable { name: name.to_string(), source: Box::new(e), }, - source => Error::Lance { source }, + e => e.into(), })?; let dataset = DatasetConsistencyWrapper::new_latest(dataset, read_consistency_interval); @@ -1529,7 +1481,7 @@ impl NativeTable { lance::Error::Namespace { source, .. } => Error::Runtime { message: format!("Failed to get table info from namespace: {:?}", source), }, - source => Error::Lance { source }, + e => e.into(), })?; let dataset = builder @@ -1541,7 +1493,7 @@ impl NativeTable { name: name.to_string(), source: Box::new(e), }, - source => Error::Lance { source }, + e => e.into(), })?; let uri = dataset.uri().to_string(); @@ -1635,7 +1587,7 @@ impl NativeTable { lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists { name: name.to_string(), }, - source => Error::Lance { source }, + e => e.into(), })?; let id = Self::build_id(&namespace, name); @@ -1662,12 +1614,12 @@ impl NativeTable { read_consistency_interval: Option, namespace_client: Option>, ) -> Result { - let batches = RecordBatchIterator::new(vec![], schema); + let data: Box = Box::new(RecordBatch::new_empty(schema)); Self::create( uri, name, namespace, - batches, + data, write_store_wrapper, params, read_consistency_interval, @@ -1756,7 +1708,7 @@ impl NativeTable { lance::Error::DatasetAlreadyExists { .. 
} => Error::TableAlreadyExists { name: name.to_string(), }, - source => Error::Lance { source }, + e => e.into(), })?; let id = Self::build_id(&namespace, name); @@ -2538,17 +2490,7 @@ impl BaseTable for NativeTable { } } - async fn add( - &self, - add: AddDataBuilder, - data: Box, - ) -> Result { - let data = Box::new(MaybeEmbedded::try_new( - data, - self.table_definition().await?, - add.embedding_registry, - )?) as Box; - + async fn add(&self, add: AddDataBuilder) -> Result { let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams { mode: match add.mode { AddDataMode::Append => WriteMode::Append, @@ -2557,6 +2499,11 @@ impl BaseTable for NativeTable { ..Default::default() }); + // Apply embeddings if configured + let table_def = self.table_definition().await?; + let data = + scannable_with_embeddings(add.data, &table_def, add.embedding_registry.as_ref())?; + let dataset = { // Limited scope for the mutable borrow of self.dataset avoids deadlock. let ds = self.dataset.get_mut().await?; @@ -3164,7 +3111,6 @@ mod tests { use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; - use lance::dataset::WriteMode; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lance::Dataset; use rand::Rng; @@ -3183,8 +3129,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); @@ -3217,9 +3164,12 @@ mod tests { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); - let batches = make_test_batches(); - let batches = Box::new(batches) as Box; - let table = NativeTable::create(uri, "test", vec![], batches, None, None, None, None) + let batch = make_test_batches(); + let reader: Box = Box::new(RecordBatchIterator::new( + vec![Ok(batch.clone())], + batch.schema(), + )); + let table = NativeTable::create(uri, "test", vec![], reader, None, None, None, None) .await .unwrap(); @@ -3233,33 +3183,6 @@ mod tests { ); } - #[tokio::test] - async fn test_add() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let conn = connect(uri).execute().await.unwrap(); - - let batches = make_test_batches(); - let schema = batches.schema().clone(); - let table = conn.create_table("test", batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - - let new_batches = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(100..110))], - ) - .unwrap()] - .into_iter() - .map(Ok), - schema.clone(), - ); - - table.add(new_batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 20); - assert_eq!(table.name(), "test"); - } - #[tokio::test] async fn test_merge_insert() { let tmp_dir = tempdir().unwrap(); @@ -3276,7 +3199,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 10); // Create new data with i=5..15 - let new_batches = Box::new(merge_insert_test_batches(5, 1)); + let new_batches = merge_insert_test_batches(5, 1); // Perform a "insert if not exists" let mut merge_insert_builder = table.merge_insert(&["i"]); @@ -3290,7 +3213,7 @@ mod tests { assert_eq!(result.num_attempts, 1); // Create new data with i=15..25 (no 
id matches) - let new_batches = Box::new(merge_insert_test_batches(15, 2)); + let new_batches = merge_insert_test_batches(15, 2); // Perform a "bulk update" (should not affect anything) let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_matched_update_all(None); @@ -3303,7 +3226,7 @@ mod tests { ); // Conditional update that only replaces the age=0 data - let new_batches = Box::new(merge_insert_test_batches(5, 3)); + let new_batches = merge_insert_test_batches(5, 3); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string())); merge_insert_builder.execute(new_batches).await.unwrap(); @@ -3329,7 +3252,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 10); // Test use_index=true (default behavior) - let new_batches = Box::new(merge_insert_test_batches(5, 1)); + let new_batches = merge_insert_test_batches(5, 1); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_not_matched_insert_all(); merge_insert_builder.use_index(true); @@ -3337,7 +3260,7 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 15); // Test use_index=false (force table scan) - let new_batches = Box::new(merge_insert_test_batches(15, 2)); + let new_batches = merge_insert_test_batches(15, 2); let mut merge_insert_builder = table.merge_insert(&["i"]); merge_insert_builder.when_not_matched_insert_all(); merge_insert_builder.use_index(false); @@ -3345,59 +3268,6 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 25); } - #[tokio::test] - async fn test_add_overwrite() { - let tmp_dir = tempdir().unwrap(); - let uri = tmp_dir.path().to_str().unwrap(); - let conn = connect(uri).execute().await.unwrap(); - - let batches = make_test_batches(); - let schema = batches.schema().clone(); - let table = conn.create_table("test", batches).execute().await.unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - - let batches = vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(100..110))], - ) - .unwrap()] - .into_iter() - .map(Ok); - - let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone()); - - // Can overwrite using AddDataOptions::mode - table - .add(new_batches) - .mode(AddDataMode::Overwrite) - .execute() - .await - .unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name(), "test"); - - // Can overwrite using underlying WriteParams (which - // take precedence over AddDataOptions::mode) - - let param: WriteParams = WriteParams { - mode: WriteMode::Overwrite, - ..Default::default() - }; - - let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone()); - table - .add(new_batches) - .write_options(WriteOptions { - lance_write_params: Some(param), - }) - .mode(AddDataMode::Append) - .execute() - .await - .unwrap(); - assert_eq!(table.count_rows(None).await.unwrap(), 10); - assert_eq!(table.name(), "test"); - } - #[derive(Default, Debug)] struct NoOpCacheWrapper { called: AtomicBool, @@ -3453,35 +3323,25 @@ mod tests { assert!(wrapper.called()); } - fn merge_insert_test_batches( - offset: i32, - age: i32, - ) -> impl RecordBatchReader + Send + Sync + 'static { + fn merge_insert_test_batches(offset: i32, age: i32) -> Box { let schema = Arc::new(Schema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("age", DataType::Int32, false), ])); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - 
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))), - Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))), - ], - )], - schema, + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(offset..(offset + 10))), + Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))), + ], ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) } - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batches() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(0..10))], - )], - schema, - ) + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from_iter_values(0..10))]).unwrap() } #[tokio::test] @@ -3569,14 +3429,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); assert_eq!(table.index_stats("my_index").await.unwrap(), None); @@ -3658,14 +3513,9 @@ mod tests { let float_arr = Float32Array::from_iter_values((0..(num_rows * dimension)).map(|v| v as f32)); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension as i32).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let native_table = table.as_native().unwrap(); let builder = IvfPqIndexBuilder::default(); table @@ -3735,14 +3585,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let stats = table.index_stats("my_index").await.unwrap(); assert!(stats.is_none()); @@ -3800,14 +3645,9 @@ mod tests { ); let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); - let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] - .into_iter() - .map(Ok), - schema, - ); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); - let table = conn.create_table("test", batches).execute().await.unwrap(); + let table = conn.create_table("test", batch).execute().await.unwrap(); let stats = table.index_stats("my_index").await.unwrap(); assert!(stats.is_none()); @@ -3850,7 +3690,7 @@ mod tests { Ok(FixedSizeListArray::from(data)) } - fn some_sample_data() -> Box { + fn some_sample_data() -> Box { let batch = RecordBatch::try_new( 
Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])), vec![Arc::new(Int32Array::from(vec![1]))], @@ -3874,10 +3714,7 @@ mod tests { .unwrap(); let conn = ConnectBuilder::new(uri).execute().await.unwrap(); let table = conn - .create_table( - "my_table", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("my_table", batch.clone()) .execute() .await .unwrap(); @@ -3956,10 +3793,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4060,10 +3894,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4123,10 +3954,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_bitmap", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_bitmap", batch.clone()) .execute() .await .unwrap(); @@ -4171,7 +3999,7 @@ mod tests { let conn1 = ConnectBuilder::new(uri).execute().await.unwrap(); let table1 = conn1 - .create_empty_table("my_table", data.schema()) + .create_empty_table("my_table", RecordBatchReader::schema(&data)) .execute() .await .unwrap(); @@ -4441,10 +4269,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_stats", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_stats", batch.clone()) .execute() .await .unwrap(); @@ -4457,21 +4282,11 @@ mod tests { ], ) .unwrap(); - table - .add(RecordBatchIterator::new( - vec![Ok(batch.clone())], - batch.schema(), - )) - .execute() - .await - .unwrap(); + table.add(batch.clone()).execute().await.unwrap(); } let empty_table = conn - .create_table( - "test_stats_empty", - RecordBatchIterator::new(vec![], batch.schema()), - ) + .create_table("test_stats_empty", RecordBatch::new_empty(batch.schema())) .execute() .await .unwrap(); @@ -4545,22 +4360,12 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_list_indices_skip_frag_reuse", - RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()), - ) + .create_table("test_list_indices_skip_frag_reuse", batch.clone()) .execute() .await .unwrap(); - table - .add(RecordBatchIterator::new( - vec![Ok(batch.clone())], - batch.schema(), - )) - .execute() - .await - .unwrap(); + table.add(batch.clone()).execute().await.unwrap(); table .create_index(&["id"], Index::Bitmap(BitmapIndexBuilder {})) @@ -4590,8 +4395,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test_ns_query.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); @@ -4643,8 +4449,9 @@ mod tests { let tmp_dir = tempdir().unwrap(); let dataset_path = tmp_dir.path().join("test_ns_plain.lance"); - let batches = make_test_batches(); - Dataset::write(batches, dataset_path.to_str().unwrap(), None) + let batch = make_test_batches(); + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); + Dataset::write(reader, dataset_path.to_str().unwrap(), None) .await .unwrap(); diff --git a/rust/lancedb/src/table/add_data.rs 
b/rust/lancedb/src/table/add_data.rs new file mode 100644 index 000000000..740e53011 --- /dev/null +++ b/rust/lancedb/src/table/add_data.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::data::scannable::Scannable; +use crate::embeddings::EmbeddingRegistry; +use crate::Result; + +use super::{BaseTable, WriteOptions}; + +#[derive(Debug, Clone, Default)] +pub enum AddDataMode { + /// Rows will be appended to the table (the default) + #[default] + Append, + /// The existing table will be overwritten with the new data + Overwrite, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct AddResult { + /// The commit version associated with the operation. + /// A version of `0` indicates compatibility with legacy servers that do not return + /// a commit version. + #[serde(default)] + pub version: u64, +} + +/// A builder for configuring a [`crate::table::Table::add`] operation +pub struct AddDataBuilder { + pub(crate) parent: Arc<dyn BaseTable>, + pub(crate) data: Box<dyn Scannable>, + pub(crate) mode: AddDataMode, + pub(crate) write_options: WriteOptions, + pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>, +} + +impl std::fmt::Debug for AddDataBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AddDataBuilder") + .field("parent", &self.parent) + .field("mode", &self.mode) + .field("write_options", &self.write_options) + .finish() + } +} + +impl AddDataBuilder { + pub(crate) fn new( + parent: Arc<dyn BaseTable>, + data: Box<dyn Scannable>, + embedding_registry: Option<Arc<dyn EmbeddingRegistry>>, + ) -> Self { + Self { + parent, + data, + mode: AddDataMode::Append, + write_options: WriteOptions::default(), + embedding_registry, + } + } + + pub fn mode(mut self, mode: AddDataMode) -> Self { + self.mode = mode; + self + } + + pub fn write_options(mut self, options: WriteOptions) -> Self { + self.write_options = options; + self + } + + pub async fn execute(self) -> Result<AddResult> { + self.parent.clone().add(self).await + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{record_batch, RecordBatch, RecordBatchIterator}; + use arrow_schema::{ArrowError, DataType, Field, Schema}; + use futures::TryStreamExt; + use lance::dataset::{WriteMode, WriteParams}; + + use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}; + use crate::connect; + use crate::data::scannable::Scannable; + use crate::embeddings::{ + EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, + }; + use crate::query::{ExecutableQuery, QueryBase, Select}; + use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions}; + use crate::test_utils::embeddings::MockEmbed; + use crate::Error; + + use super::AddDataMode; + + async fn create_test_table() -> Table { + let conn = connect("memory://").execute().await.unwrap(); + let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap(); + conn.create_table("test", batch).execute().await.unwrap() + } + + async fn test_add_with_data<T>(data: T) + where + T: Scannable + 'static, + { + let table = create_test_table().await; + let schema = data.schema(); + table.add(data).execute().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 5); // 3 initial + 2 added + assert_eq!(table.schema().await.unwrap(), schema); + } + + #[tokio::test] + async fn test_add_with_batch() { + let batch = record_batch!(("id", Int64, [4, 5])).unwrap(); + test_add_with_data(batch).await; + } + 
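+    // Illustrative sketch (added for this write-up, not part of the upstream
+    // patch): contrasts the two kinds of sources `Table::add` accepts. A
+    // `RecordBatch` is rescannable (a fresh scan can be produced per attempt,
+    // so the remote client may safely retry), while a boxed
+    // `RecordBatchReader` is consumed on its first scan.
+    #[tokio::test]
+    async fn test_add_batch_then_reader_sketch() {
+        let table = create_test_table().await;
+        let batch = record_batch!(("id", Int64, [4, 5])).unwrap();
+        // Rescannable: pass the batch by value.
+        table.add(batch.clone()).execute().await.unwrap();
+        // One-shot: wrap the same rows in a reader; also accepted by `add`.
+        let reader: Box<dyn arrow_array::RecordBatchReader + Send> = Box::new(
+            RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()),
+        );
+        table.add(reader).execute().await.unwrap();
+        assert_eq!(table.count_rows(None).await.unwrap(), 7); // 3 initial + 2 + 2
+    }
+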
#[tokio::test] + async fn test_add_with_vec_batch() { + let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + test_add_with_data(data).await; + } + + #[tokio::test] + async fn test_add_with_record_batch_reader() { + let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + let schema = data[0].schema(); + let reader: Box = Box::new( + RecordBatchIterator::new(data.into_iter().map(Ok), schema.clone()), + ); + test_add_with_data(reader).await; + } + + #[tokio::test] + async fn test_add_with_stream() { + let data = vec![ + record_batch!(("id", Int64, [4])).unwrap(), + record_batch!(("id", Int64, [5])).unwrap(), + ]; + let schema = data[0].schema(); + let inner = futures::stream::iter(data.into_iter().map(Ok)); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema, + stream: inner, + }); + test_add_with_data(stream).await; + } + + #[derive(Debug)] + struct MyError; + + impl std::fmt::Display for MyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MyError occurred") + } + } + + impl std::error::Error for MyError {} + + #[tokio::test] + async fn test_add_preserves_reader_error() { + let table = create_test_table().await; + let first_batch = record_batch!(("id", Int64, [4])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(ArrowError::ExternalError(Box::new(MyError))), + ]; + let reader: Box = Box::new( + RecordBatchIterator::new(iterator.into_iter(), schema.clone()), + ); + + let result = table.add(reader).execute().await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_add_preserves_stream_error() { + let table = create_test_table().await; + let first_batch = record_batch!(("id", Int64, [4])).unwrap(); + let schema = first_batch.schema(); + let iterator = vec![ + Ok(first_batch), + Err(Error::External { + source: Box::new(MyError), + }), + ]; + let stream = futures::stream::iter(iterator); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream, + }); + + let result = table.add(stream).execute().await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_add() { + let conn = connect("memory://").execute().await.unwrap(); + + let batch = record_batch!(("i", Int32, [0, 1, 2])).unwrap(); + let table = conn + .create_table("test", batch.clone()) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let new_batch = record_batch!(("i", Int32, [3])).unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 4); + assert_eq!(table.schema().await.unwrap(), batch.schema()); + } + + #[tokio::test] + async fn test_add_overwrite() { + let conn = connect("memory://").execute().await.unwrap(); + + let batch = record_batch!(("i", Int32, [0, 1, 2])).unwrap(); + let table = conn + .create_table("test", batch.clone()) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), batch.num_rows()); + + let new_batch = record_batch!(("x", Float32, [0.0, 1.0])).unwrap(); + let res = table + .add(new_batch.clone()) + .mode(AddDataMode::Overwrite) + .execute() + .await + .unwrap(); + assert_eq!(res.version, table.version().await.unwrap()); + assert_eq!(table.count_rows(None).await.unwrap(), new_batch.num_rows()); + assert_eq!(table.schema().await.unwrap(), 
new_batch.schema()); + + // Can overwrite using underlying WriteParams (which + // take precedence over AddDataMode) + let param: WriteParams = WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }; + + table + .add(new_batch.clone()) + .write_options(WriteOptions { + lance_write_params: Some(param), + }) + .mode(AddDataMode::Append) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), new_batch.num_rows()); + } + + #[tokio::test] + async fn test_add_with_embeddings() { + let registry = Arc::new(MemoryRegistry::new()); + let mock_embedding: Arc = Arc::new(MockEmbed::new("mock", 4)); + registry.register("mock", mock_embedding).unwrap(); + + let conn = connect("memory://") + .embedding_registry(registry) + .execute() + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new( + "text_embedding", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])); + + // Add embedding metadata to the schema + let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_embedding")); + let table_def = TableDefinition::new( + schema.clone(), + vec![ + ColumnDefinition { + kind: ColumnKind::Physical, + }, + ColumnDefinition { + kind: ColumnKind::Embedding(embedding_def), + }, + ], + ); + let rich_schema = table_def.into_rich_schema(); + + let table = conn + .create_empty_table("embed_test", rich_schema) + .execute() + .await + .unwrap(); + + // Now add new data WITHOUT the embedding column - it should be computed automatically + let new_batch = record_batch!(("text", Utf8, ["hello", "world"])).unwrap(); + table.add(new_batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 2); + + // Query to verify the embeddings were computed for the new rows + let results: Vec = table + .query() + .select(Select::columns(&["text", "text_embedding"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + + // Check that all rows have embedding values (not null) + for batch in &results { + let embedding_col = batch.column(1); + assert_eq!(embedding_col.null_count(), 0); + } + } +} diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs index 760871631..c7ced18ea 100644 --- a/rust/lancedb/src/table/datafusion.rs +++ b/rust/lancedb/src/table/datafusion.rs @@ -287,8 +287,7 @@ pub mod tests { use arrow::array::AsArray; use arrow_array::{ - BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, RecordBatchIterator, - RecordBatchReader, StringArray, UInt32Array, + BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, UInt32Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::{ @@ -308,7 +307,7 @@ pub mod tests { table::datafusion::BaseTableAdapter, }; - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batches() -> RecordBatch { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let schema = Arc::new( Schema::new(vec![ @@ -317,19 +316,17 @@ pub mod tests { ]) .with_metadata(metadata), ); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(UInt32Array::from_iter_values(0..10)), - ], - )], + RecordBatch::try_new( schema, + vec![ + 
Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + ], ) + .unwrap() } - fn make_tbl_two_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_tbl_two_test_batches() -> RecordBatch { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let schema = Arc::new( Schema::new(vec![ @@ -342,28 +339,26 @@ pub mod tests { ]) .with_metadata(metadata), ); - RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int64Array::from_iter_values(0..1000)), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| i.to_string()), - )), - Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| format!("{{\"i\":{}}}", i)), - )), - Arc::new(BinaryArray::from_iter_values( - (0..1000).map(|i| (i as u32).to_be_bytes().to_vec()), - )), - Arc::new(StringArray::from_iter_values( - (0..1000).map(|i| i.to_string()), - )), - ], - )], + RecordBatch::try_new( schema, + vec![ + Arc::new(Int64Array::from_iter_values(0..1000)), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| i.to_string()), + )), + Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| format!("{{\"i\":{}}}", i)), + )), + Arc::new(BinaryArray::from_iter_values( + (0..1000).map(|i| (i as u32).to_be_bytes().to_vec()), + )), + Arc::new(StringArray::from_iter_values( + (0..1000).map(|i| i.to_string()), + )), + ], ) + .unwrap() } struct TestFixture { diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs index e3ee371ed..53ae9520c 100644 --- a/rust/lancedb/src/table/datafusion/insert.rs +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -222,7 +222,7 @@ mod tests { use std::vec; use super::*; - use arrow_array::{record_batch, Int32Array, RecordBatchIterator}; + use arrow_array::{record_batch, RecordBatchIterator}; use datafusion::prelude::SessionContext; use datafusion_catalog::MemTable; use tempfile::tempdir; @@ -238,11 +238,8 @@ mod tests { // Create initial table let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - let table = db - .create_table("test_insert", Box::new(reader)) + .create_table("test_insert", batch) .execute() .await .unwrap(); @@ -279,11 +276,8 @@ mod tests { // Create initial table with 3 rows let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - let table = db - .create_table("test_overwrite", Box::new(reader)) + .create_table("test_overwrite", batch) .execute() .await .unwrap(); @@ -318,20 +312,9 @@ mod tests { let db = connect(uri).execute().await.unwrap(); // Create initial table - let schema = Arc::new(ArrowSchema::new(vec![Field::new( - "id", - DataType::Int32, - false, - )])); - let batches = vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - ) - .unwrap()]; - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); let table = db - .create_table("test_empty", Box::new(reader)) + .create_table("test_empty", batch) .execute() .await .unwrap(); @@ -352,12 +335,13 @@ mod tests { false, )])); // Empty batches - let source_reader = 
RecordBatchIterator::new( - std::iter::empty::>(), - source_schema, - ); + let source_reader: Box = + Box::new(RecordBatchIterator::new( + std::iter::empty::>(), + source_schema, + )); let source_table = db - .create_table("empty_source", Box::new(source_reader)) + .create_table("empty_source", source_reader) .execute() .await .unwrap(); @@ -389,20 +373,10 @@ mod tests { let db = connect(uri).execute().await.unwrap(); // Create initial table - let schema = Arc::new(ArrowSchema::new(vec![Field::new( - "id", - DataType::Int32, - true, - )])); - let batches = - vec![ - RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))]) - .unwrap(), - ]; - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - + let batch = record_batch!(("id", Int32, [1])).unwrap(); + let schema = batch.schema(); let table = db - .create_table("test_multi_batch", Box::new(reader)) + .create_table("test_multi_batch", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/datafusion/udtf/fts.rs b/rust/lancedb/src/table/datafusion/udtf/fts.rs index c3eea0574..2f8b8beb8 100644 --- a/rust/lancedb/src/table/datafusion/udtf/fts.rs +++ b/rust/lancedb/src/table/datafusion/udtf/fts.rs @@ -97,7 +97,7 @@ mod tests { table::datafusion::BaseTableAdapter, Connection, Table, }; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use datafusion::prelude::SessionContext; @@ -173,14 +173,7 @@ mod tests { // Create LanceDB database and table let db = crate::connect("memory://test").execute().await.unwrap(); - let table = db - .create_table( - "foo", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("foo", batch).execute().await.unwrap(); // Create FTS index table @@ -323,13 +316,7 @@ mod tests { RecordBatch::try_new(metadata_schema.clone(), vec![metadata_col, extra_col]).unwrap(); let _metadata_table = db - .create_table( - "metadata", - RecordBatchIterator::new( - vec![Ok(metadata_batch.clone())].into_iter(), - metadata_schema.clone(), - ), - ) + .create_table("metadata", metadata_batch.clone()) .execute() .await .unwrap(); @@ -393,14 +380,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, text_col, category_col]).unwrap(); - let table = db - .create_table( - table_name, - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table(table_name, batch).execute().await.unwrap(); // Create FTS index table @@ -546,14 +526,7 @@ mod tests { ])); let batch = RecordBatch::try_new(schema.clone(), vec![id_col, text_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS index with position information for phrase queries table @@ -691,14 +664,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "multi_col", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("multi_col", batch).execute().await.unwrap(); // Create FTS indices on both columns table @@ -963,13 +929,7 @@ mod tests { let 
metadata_batch = RecordBatch::try_new(metadata_schema.clone(), vec![metadata_id, extra_info]).unwrap(); let _metadata_table = db - .create_table( - "metadata", - RecordBatchIterator::new( - vec![Ok(metadata_batch.clone())].into_iter(), - metadata_schema, - ), - ) + .create_table("metadata", metadata_batch.clone()) .execute() .await .unwrap(); @@ -1358,14 +1318,7 @@ mod tests { ])); let batch = RecordBatch::try_new(schema.clone(), vec![id_col, text_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS index with position information table @@ -1510,14 +1463,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS indices on both columns table @@ -1591,14 +1537,7 @@ mod tests { let batch = RecordBatch::try_new(schema.clone(), vec![id_col, title_col, content_col]).unwrap(); - let table = db - .create_table( - "docs", - RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema), - ) - .execute() - .await - .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); // Create FTS indices table @@ -1724,36 +1663,23 @@ mod tests { .unwrap(); // Create table with simple text for n-gram testing - let data = RecordBatchIterator::new( - vec![RecordBatch::try_new( - Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("text", DataType::Utf8, false), - ])), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec![ - "hello world", - "lance database", - "lance is cool", - ])), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let data = RecordBatch::try_new( Arc::new(ArrowSchema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("text", DataType::Utf8, false), ])), - ); + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "hello world", + "lance database", + "lance is cool", + ])), + ], + ) + .unwrap(); - let table = Arc::new( - db.create_table("docs", Box::new(data)) - .execute() - .await - .unwrap(), - ); + let table = Arc::new(db.create_table("docs", data).execute().await.unwrap()); // Create FTS index with n-gram tokenizer (default min_ngram_length=3) table @@ -1876,43 +1802,29 @@ mod tests { .unwrap(); // Create table with two text columns - let data = RecordBatchIterator::new( - vec![RecordBatch::try_new( - Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("title", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), - ])), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec![ - "Important Document", - "Another Document", - "Random Text", - ])), - Arc::new(StringArray::from(vec![ - "This is important information", - "This has details", - "Nothing special here", - ])), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let data = RecordBatch::try_new( Arc::new(ArrowSchema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("title", DataType::Utf8, false), Field::new("content", DataType::Utf8, false), ])), - ); + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + 
Arc::new(StringArray::from(vec![ + "Important Document", + "Another Document", + "Random Text", + ])), + Arc::new(StringArray::from(vec![ + "This is important information", + "This has details", + "Nothing special here", + ])), + ], + ) + .unwrap(); - let table = Arc::new( - db.create_table("docs", Box::new(data)) - .execute() - .await - .unwrap(), - ); + let table = Arc::new(db.create_table("docs", data).execute().await.unwrap()); // Create FTS indices on both columns table diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs index 7d17f46a2..4c4c304ba 100644 --- a/rust/lancedb/src/table/delete.rs +++ b/rust/lancedb/src/table/delete.rs @@ -34,7 +34,7 @@ pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Resu #[cfg(test)] mod tests { use crate::connect; - use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_array::{record_batch, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use std::sync::Arc; @@ -53,10 +53,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_delete", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_delete", batch) .execute() .await .unwrap(); @@ -102,10 +99,7 @@ mod tests { let original_schema = batch.schema(); let table = conn - .create_table( - "test_delete_all", - RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()), - ) + .create_table("test_delete_all", batch) .execute() .await .unwrap(); @@ -126,13 +120,8 @@ mod tests { // Create a table with 5 rows let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); - let schema = batch.schema(); - let table = conn - .create_table( - "test_delete_noop", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_delete_noop", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/optimize.rs b/rust/lancedb/src/table/optimize.rs index e278a8dbe..f75671e5b 100644 --- a/rust/lancedb/src/table/optimize.rs +++ b/rust/lancedb/src/table/optimize.rs @@ -212,7 +212,7 @@ pub(crate) async fn execute_optimize( #[cfg(test)] mod tests { - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use rstest::rstest; use std::sync::Arc; @@ -236,10 +236,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_compact", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_compact", batch) .execute() .await .unwrap(); @@ -253,11 +250,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Verify we have multiple fragments before compaction @@ -322,10 +315,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_prune", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_prune", batch) .execute() .await .unwrap(); @@ -339,11 +329,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Verify multiple versions exist @@ -405,10 +391,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_index_optimize", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_index_optimize", batch) .execute() 
.await .unwrap(); @@ -426,11 +409,7 @@ mod tests { vec![Arc::new(Int32Array::from_iter_values(100..200))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Verify index stats before optimization let indices = table.list_indices().await.unwrap(); @@ -474,10 +453,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_optimize_all", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_optimize_all", batch) .execute() .await .unwrap(); @@ -491,11 +467,7 @@ mod tests { ))], ) .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); } // Run all optimizations @@ -559,20 +531,13 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_deferred_remap", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_deferred_remap", batch.clone()) .execute() .await .unwrap(); // Add more data - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Create an index table @@ -648,20 +613,13 @@ mod tests { let original_schema = batch.schema(); let table = conn - .create_table( - "test_schema_preserved", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_schema_preserved", batch.clone()) .execute() .await .unwrap(); // Add more data - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); // Run compaction table @@ -703,10 +661,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_empty_optimize", - RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), - ) + .create_table("test_empty_optimize", batch) .execute() .await .unwrap(); @@ -752,19 +707,12 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_checkout_optimize", - RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()), - ) + .create_table("test_checkout_optimize", batch.clone()) .execute() .await .unwrap(); - table - .add(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) - .execute() - .await - .unwrap(); + table.add(batch).execute().await.unwrap(); table.checkout(1).await.unwrap(); diff --git a/rust/lancedb/src/table/schema_evolution.rs b/rust/lancedb/src/table/schema_evolution.rs index 0f07e0a0b..3f774cd98 100644 --- a/rust/lancedb/src/table/schema_evolution.rs +++ b/rust/lancedb/src/table/schema_evolution.rs @@ -89,7 +89,7 @@ pub(crate) async fn execute_drop_columns( #[cfg(test)] mod tests { - use arrow_array::{record_batch, Int32Array, RecordBatchIterator, StringArray}; + use arrow_array::{record_batch, Int32Array, StringArray}; use arrow_schema::DataType; use futures::TryStreamExt; use lance::dataset::ColumnAlteration; @@ -105,13 +105,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_columns", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_columns", batch) .execute() .await .unwrap(); @@ -169,13 +165,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("x", Int32, [10, 20, 
30])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_multi_columns", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_multi_columns", batch) .execute() .await .unwrap(); @@ -205,13 +197,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_add_const_column", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_add_const_column", batch) .execute() .await .unwrap(); @@ -255,13 +243,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("old_name", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_rename", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_rename", batch) .execute() .await .unwrap(); @@ -304,10 +288,7 @@ mod tests { .unwrap(); let table = conn - .create_table( - "test_alter_nullable", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_nullable", batch) .execute() .await .unwrap(); @@ -332,13 +313,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_cast_type", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_cast_type", batch) .execute() .await .unwrap(); @@ -379,13 +356,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_invalid_cast", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_invalid_cast", batch) .execute() .await .unwrap(); @@ -407,13 +380,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_multi", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_multi", batch) .execute() .await .unwrap(); @@ -441,13 +410,9 @@ mod tests { let batch = record_batch!(("keep", Int32, [1, 2, 3]), ("remove", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_single", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_single", batch) .execute() .await .unwrap(); @@ -478,13 +443,9 @@ mod tests { ("d", Int32, [7, 8]) ) .unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_multi", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_multi", batch) .execute() .await .unwrap(); @@ -511,13 +472,9 @@ mod tests { ("extra", Int32, [10, 20, 30]) ) .unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_preserves", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_preserves", batch) .execute() .await .unwrap(); @@ -567,13 +524,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_drop_nonexistent", - 
RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_drop_nonexistent", batch) .execute() .await .unwrap(); @@ -593,13 +546,9 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap(); - let schema = batch.schema(); let table = conn - .create_table( - "test_alter_nonexistent", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_alter_nonexistent", batch) .execute() .await .unwrap(); @@ -623,13 +572,8 @@ mod tests { let conn = connect("memory://").execute().await.unwrap(); let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap(); - let schema = batch.schema(); - let table = conn - .create_table( - "test_version_increment", - RecordBatchIterator::new(vec![Ok(batch)], schema), - ) + .create_table("test_version_increment", batch) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/table/update.rs b/rust/lancedb/src/table/update.rs index 6616dddc2..1d5ca5f4d 100644 --- a/rust/lancedb/src/table/update.rs +++ b/rust/lancedb/src/table/update.rs @@ -117,9 +117,8 @@ mod tests { use crate::query::{ExecutableQuery, Select}; use arrow_array::{ record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, - Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator, - RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray, - UInt32Array, + Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, StringArray, + TimestampMillisecondArray, TimestampNanosecondArray, UInt32Array, }; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit}; @@ -167,51 +166,46 @@ mod tests { ), ])); - let record_batch_iter = RecordBatchIterator::new( - vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)), - Arc::new(Int64Array::from_iter_values(0..10)), - Arc::new(UInt32Array::from_iter_values(0..10)), - Arc::new(StringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - Arc::new(LargeStringArray::from_iter_values(vec![ - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", - ])), - Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), - Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), - Arc::new(Into::::into(vec![ - true, false, true, false, true, false, true, false, true, false, - ])), - Arc::new(Date32Array::from_iter_values(0..10)), - Arc::new(TimestampNanosecondArray::from_iter_values(0..10)), - Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), - Arc::new( - create_fixed_size_list( - Float32Array::from_iter_values((0..20).map(|i| i as f32)), - 2, - ) - .unwrap(), - ), - Arc::new( - create_fixed_size_list( - Float64Array::from_iter_values((0..20).map(|i| i as f64)), - 2, - ) - .unwrap(), - ), - ], - ) - .unwrap()] - .into_iter() - .map(Ok), + let batch = RecordBatch::try_new( schema.clone(), - ); + vec![ + Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(Int64Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + Arc::new(StringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(LargeStringArray::from_iter_values(vec![ + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + ])), + Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), + Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as 
+                Arc::new(Into::<BooleanArray>::into(vec![
+                    true, false, true, false, true, false, true, false, true, false,
+                ])),
+                Arc::new(Date32Array::from_iter_values(0..10)),
+                Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
+                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
+                Arc::new(
+                    create_fixed_size_list(
+                        Float32Array::from_iter_values((0..20).map(|i| i as f32)),
+                        2,
+                    )
+                    .unwrap(),
+                ),
+                Arc::new(
+                    create_fixed_size_list(
+                        Float64Array::from_iter_values((0..20).map(|i| i as f64)),
+                        2,
+                    )
+                    .unwrap(),
+                ),
+            ],
+        )
+        .unwrap();

         let table = conn
-            .create_table("my_table", record_batch_iter)
+            .create_table("my_table", batch)
             .execute()
             .await
             .unwrap();
@@ -338,15 +332,13 @@ mod tests {
         Ok(FixedSizeListArray::from(data))
     }

-    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
+    fn make_test_batch() -> RecordBatch {
         let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
-        RecordBatchIterator::new(
-            vec![RecordBatch::try_new(
-                schema.clone(),
-                vec![Arc::new(Int32Array::from_iter_values(0..10))],
-            )],
-            schema,
+        RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_iter_values(0..10))],
         )
+        .unwrap()
     }
@@ -367,12 +359,8 @@ mod tests {
         )
         .unwrap();

-        let schema = batch.schema();
-        // need the iterator for create table
-        let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema);
-
         let table = conn
-            .create_table("my_table", record_batch_iter)
+            .create_table("my_table", batch)
             .execute()
             .await
             .unwrap();
@@ -430,7 +418,7 @@ mod tests {
             .await
             .unwrap();
         let tbl = conn
-            .create_table("my_table", make_test_batches())
+            .create_table("my_table", make_test_batch())
             .execute()
             .await
             .unwrap();
diff --git a/rust/lancedb/src/test_utils.rs b/rust/lancedb/src/test_utils.rs
index daf749bcc..1e7fc742b 100644
--- a/rust/lancedb/src/test_utils.rs
+++ b/rust/lancedb/src/test_utils.rs
@@ -3,3 +3,4 @@

 pub mod connection;
 pub mod datagen;
+pub mod embeddings;
diff --git a/rust/lancedb/src/test_utils/datagen.rs b/rust/lancedb/src/test_utils/datagen.rs
index 15b79e5d8..6439fbc5f 100644
--- a/rust/lancedb/src/test_utils/datagen.rs
+++ b/rust/lancedb/src/test_utils/datagen.rs
@@ -34,10 +34,7 @@ impl LanceDbDatagenExt for BatchGeneratorBuilder {
             schema,
         ));
         let db = connect("memory:///").execute().await.unwrap();
-        db.create_table_streaming(table_name, stream)
-            .execute()
-            .await
-            .unwrap()
+        db.create_table(table_name, stream).execute().await.unwrap()
     }
 }

@@ -48,8 +45,5 @@ pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
         schema,
     ));
     let db = connect("memory:///").execute().await.unwrap();
-    db.create_table_streaming(name, stream)
-        .execute()
-        .await
-        .unwrap()
+    db.create_table(name, stream).execute().await.unwrap()
 }
diff --git a/rust/lancedb/src/test_utils/embeddings.rs b/rust/lancedb/src/test_utils/embeddings.rs
new file mode 100644
index 000000000..48c9b5743
--- /dev/null
+++ b/rust/lancedb/src/test_utils/embeddings.rs
@@ -0,0 +1,59 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::{borrow::Cow, sync::Arc};
+
+use arrow_array::{Array, FixedSizeListArray, Float32Array};
+use arrow_schema::{DataType, Field};
+
+use crate::embeddings::EmbeddingFunction;
+use crate::Result;
+
+#[derive(Debug, Clone)]
+pub struct MockEmbed {
+    name: String,
+    dim: usize,
+}
+
+impl MockEmbed {
+    pub fn new(name: impl Into<String>, dim: usize) -> Self {
+        Self {
+            name: name.into(),
+            dim,
+        }
+    }
+}
+
+impl EmbeddingFunction for MockEmbed {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn source_type(&self) -> Result<Cow<DataType>> {
+        Ok(Cow::Borrowed(&DataType::Utf8))
+    }
+
+    fn dest_type(&self) -> Result<Cow<DataType>> {
+        Ok(Cow::Owned(DataType::new_fixed_size_list(
+            DataType::Float32,
+            self.dim as _,
+            true,
+        )))
+    }
+
+    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
+        // and we want to explicitly work with non-nullable arrays.
+        let len = source.len();
+        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
+        let field = Field::new("item", inner.data_type().clone(), false);
+        let arr = FixedSizeListArray::new(Arc::new(field), self.dim as _, inner, None);
+
+        Ok(Arc::new(arr))
+    }
+
+    #[allow(unused_variables)]
+    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+        todo!()
+    }
+}
diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs
index 4c636aad4..b002e47ff 100644
--- a/rust/lancedb/tests/embedding_registry_test.rs
+++ b/rust/lancedb/tests/embedding_registry_test.rs
@@ -15,7 +15,6 @@ use arrow_array::{
 use arrow_schema::{DataType, Field, Schema};
 use futures::StreamExt;
 use lancedb::{
-    arrow::IntoArrow,
     connect,
     embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
     query::ExecutableQuery,
@@ -253,7 +252,7 @@ async fn test_no_func_in_registry_on_add() -> Result<()> {
     Ok(())
 }

-fn create_some_records() -> Result<impl IntoArrow> {
+fn create_some_records() -> Result<Box<dyn RecordBatchReader + Send>> {
     const TOTAL: usize = 2;

     let schema = Arc::new(Schema::new(vec![
diff --git a/rust/lancedb/tests/object_store_test.rs b/rust/lancedb/tests/object_store_test.rs
index c874fefb9..0d06eb678 100644
--- a/rust/lancedb/tests/object_store_test.rs
+++ b/rust/lancedb/tests/object_store_test.rs
@@ -4,7 +4,7 @@
 #![cfg(feature = "s3-test")]

 use std::sync::Arc;

-use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
+use arrow_array::{Int32Array, RecordBatch, StringArray};
 use arrow_schema::{DataType, Field, Schema};
 use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig};
@@ -111,7 +111,6 @@ async fn test_minio_lifecycle() -> Result<()> {
         .await?;

     let data = test_data();
-    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());

     let table = db.create_table("test_table", data).execute().await?;
@@ -127,7 +126,6 @@ async fn test_minio_lifecycle() -> Result<()> {
     assert_eq!(row_count, 3);

     let data = test_data();
-    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
     table.add(data).execute().await?;

     db.drop_table("test_table", &[]).await?;
@@ -247,7 +245,6 @@ async fn test_encryption() -> Result<()> {

     // Create a table with encryption
     let data = test_data();
-    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
     let mut builder = db.create_table("test_table", data);

     for (key, value) in CONFIG {
@@ -274,7 +271,6 @@ async fn test_encryption() -> Result<()> {
     let table = db.open_table("test_table").execute().await?;

     let data = test_data();
-    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
     table.add(data).execute().await?;

     validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
@@ -300,7 +296,6 @@ async fn test_table_storage_options_override() -> Result<()> {

     // Create table overriding with key2 encryption
     let data = test_data();
-    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
     let _table = db
.create_table("test_override", data) .storage_option("aws_sse_kms_key_id", &key2.0) @@ -312,7 +307,6 @@ async fn test_table_storage_options_override() -> Result<()> { // Also test that a table created without override uses connection settings let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let _table2 = db.create_table("test_inherit", data).execute().await?; // Verify this table uses key1 from connection @@ -419,7 +413,6 @@ async fn test_concurrent_dynamodb_commit() { .unwrap(); let data = test_data(); - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); let table = db.create_table("test_table", data).execute().await.unwrap(); @@ -430,7 +423,6 @@ async fn test_concurrent_dynamodb_commit() { let table = db.open_table("test_table").execute().await.unwrap(); let data = data.clone(); tasks.push(tokio::spawn(async move { - let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema()); table.add(data).execute().await.unwrap(); })); }