diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs
index 32adef9a..133f5d1f 100644
--- a/rust/lancedb/src/arrow.rs
+++ b/rust/lancedb/src/arrow.rs
@@ -4,12 +4,14 @@
 use std::{pin::Pin, sync::Arc};
 
 pub use arrow_schema;
-use futures::{Stream, StreamExt};
+use datafusion_common::DataFusionError;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryStreamExt};
 #[cfg(feature = "polars")]
 use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
 
-use crate::error::Result;
+use crate::{error::Result, Error};
 
 /// An iterator of batches that also has a schema
 pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch>> {
@@ -65,6 +67,20 @@ impl From for SendableRecordBatchS
     }
 }
 
+pub trait SendableRecordBatchStreamExt {
+    fn into_df_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream;
+}
+
+impl SendableRecordBatchStreamExt for SendableRecordBatchStream {
+    fn into_df_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream {
+        let schema = self.schema();
+        Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            self.map_err(|ldb_err| DataFusionError::External(ldb_err.into())),
+        ))
+    }
+}
+
 /// A simple RecordBatchStream formed from the two parts (stream + schema)
 #[pin_project::pin_project]
 pub struct SimpleRecordBatchStream<S: Stream<Item = Result<RecordBatch>>> {
@@ -101,7 +117,7 @@ impl<S: Stream<Item = Result<RecordBatch>>> RecordBatchStream
 /// used in methods like [`crate::connection::Connection::create_table`]
 /// or [`crate::table::Table::add`]
 pub trait IntoArrow {
-    /// Convert the data into an Arrow array
+    /// Convert the data into an iterator of Arrow batches
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
 }
 
@@ -113,11 +129,38 @@ impl IntoArrow for T {
     }
 }
 
+/// A trait for converting incoming data to Arrow asynchronously
+///
+/// Serves the same purpose as [`IntoArrow`], but for asynchronous data.
+///
+/// Note: Arrow has no async equivalent to RecordBatchReader, so this trait
+/// produces the crate's own [`SendableRecordBatchStream`] instead.
+pub trait IntoArrowStream {
+    /// Convert the data into a stream of Arrow batches
+    fn into_arrow(self) -> Result<SendableRecordBatchStream>;
+}
+
 impl<S: Stream<Item = Result<RecordBatch>>> SimpleRecordBatchStream<S> {
     pub fn new(stream: S, schema: Arc<Schema>) -> Self {
         Self { schema, stream }
     }
 }
+
+impl IntoArrowStream for SendableRecordBatchStream {
+    fn into_arrow(self) -> Result<SendableRecordBatchStream> {
+        Ok(self)
+    }
+}
+
+impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
+    fn into_arrow(self) -> Result<SendableRecordBatchStream> {
+        let schema = self.schema();
+        let stream = self.map_err(|df_err| Error::Runtime {
+            message: df_err.to_string(),
+        });
+        Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
+    }
+}
+
 #[cfg(feature = "polars")]
 /// An iterator of record batches formed from a Polars DataFrame.
 pub struct PolarsDataFrameRecordBatchReader {
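The two conversions above are designed to round-trip: `SendableRecordBatchStreamExt::into_df_stream` turns a LanceDB stream into a DataFusion one, and `IntoArrowStream::into_arrow` turns a DataFusion stream back into a LanceDB one. A minimal sketch, not part of the patch, assuming this branch of `lancedb` with the items above reachable under `lancedb::arrow`:

```rust
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::Schema;
use futures::stream;
use lancedb::arrow::{
    IntoArrowStream, SendableRecordBatchStream, SendableRecordBatchStreamExt,
    SimpleRecordBatchStream,
};

/// Round-trips in-memory batches: LanceDB stream -> DataFusion stream -> LanceDB stream.
fn round_trip(schema: Arc<Schema>, batches: Vec<RecordBatch>) -> lancedb::Result<()> {
    // Wrap the batches as a LanceDB stream (items are lancedb::Result<RecordBatch>).
    let ldb_stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
        stream::iter(batches.into_iter().map(Ok::<_, lancedb::Error>)),
        schema,
    ));

    // LanceDB -> DataFusion: errors are wrapped as DataFusionError::External.
    let df_stream = ldb_stream.into_df_stream();

    // DataFusion -> LanceDB: errors come back as lancedb::Error::Runtime.
    let _ldb_again = df_stream.into_arrow()?;
    Ok(())
}
```

The error mapping is deliberately asymmetric: outgoing LanceDB errors are preserved as `DataFusionError::External`, while incoming DataFusion errors are flattened into `Error::Runtime` carrying only the message.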
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index 84e01208..7685da3d 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -11,7 +11,7 @@
 use arrow_schema::{Field, SchemaRef};
 use lance::dataset::ReadParams;
 use object_store::aws::AwsCredential;
-use crate::arrow::IntoArrow;
+use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream};
 use crate::database::listing::{
     ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS,
 };
@@ -75,6 +75,14 @@ impl IntoArrow for NoData {
     }
 }
 
+// Stores the data given to the initial CreateTableBuilder::new call
+// and defers errors until `execute` is called
+enum CreateTableBuilderInitialData {
+    None,
+    Iterator(Result<Box<dyn arrow_array::RecordBatchReader + Send>>),
+    Stream(Result<SendableRecordBatchStream>),
+}
+
 /// A builder for configuring a [`Connection::create_table`] operation
 pub struct CreateTableBuilder<const HAS_DATA: bool, T> {
     parent: Arc<dyn Database>,
@@ -83,7 +91,7 @@ pub struct CreateTableBuilder {
     request: CreateTableRequest,
     // This is a bit clumsy but we defer errors until `execute` is called
     // to maintain backwards compatibility
-    data: Option<Result<Box<dyn arrow_array::RecordBatchReader + Send>>>,
+    data: CreateTableBuilderInitialData,
 }
 
 // Builder methods that only apply when we have initial data
@@ -103,7 +111,26 @@ impl CreateTableBuilder {
             ),
             embeddings: Vec::new(),
             embedding_registry,
-            data: Some(data.into_arrow()),
+            data: CreateTableBuilderInitialData::Iterator(data.into_arrow()),
         }
     }
+
+    fn new_streaming(
+        parent: Arc<dyn Database>,
+        name: String,
+        data: T,
+        embedding_registry: Arc<dyn EmbeddingRegistry>,
+    ) -> Self {
+        let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
+        Self {
+            parent,
+            request: CreateTableRequest::new(
+                name,
+                CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
+            ),
+            embeddings: Vec::new(),
+            embedding_registry,
+            data: CreateTableBuilderInitialData::Stream(data.into_arrow()),
+        }
+    }
 }
@@ -125,17 +152,37 @@ impl CreateTableBuilder {
     }
 
     fn into_request(self) -> Result<CreateTableRequest> {
-        let data = if self.embeddings.is_empty() {
-            self.data.unwrap()?
+        if self.embeddings.is_empty() {
+            match self.data {
+                CreateTableBuilderInitialData::Iterator(maybe_iter) => {
+                    let data = maybe_iter?;
+                    Ok(CreateTableRequest {
+                        data: CreateTableData::Data(data),
+                        ..self.request
+                    })
+                }
+                CreateTableBuilderInitialData::None => {
+                    unreachable!("No data provided for CreateTableBuilder")
+                }
+                CreateTableBuilderInitialData::Stream(maybe_stream) => {
+                    let data = maybe_stream?;
+                    Ok(CreateTableRequest {
+                        data: CreateTableData::StreamingData(data),
+                        ..self.request
+                    })
+                }
+            }
         } else {
-            let data = self.data.unwrap()?;
-            Box::new(WithEmbeddings::new(data, self.embeddings))
-        };
-        let req = self.request;
-        Ok(CreateTableRequest {
-            data: CreateTableData::Data(data),
-            ..req
-        })
+            let CreateTableBuilderInitialData::Iterator(maybe_iter) = self.data else {
+                return Err(Error::NotSupported { message: "Creating a table with embeddings is currently not supported when the input is streaming".to_string() });
+            };
+            let data = maybe_iter?;
+            let data = Box::new(WithEmbeddings::new(data, self.embeddings));
+            Ok(CreateTableRequest {
+                data: CreateTableData::Data(data),
+                ..self.request
+            })
+        }
     }
 }
 
@@ -151,7 +198,7 @@ impl CreateTableBuilder {
         Self {
             parent,
             request: CreateTableRequest::new(name, CreateTableData::Empty(table_definition)),
-            data: None,
+            data: CreateTableBuilderInitialData::None,
             embeddings: Vec::default(),
             embedding_registry,
         }
     }
@@ -432,7 +479,7 @@ impl Connection {
         TableNamesBuilder::new(self.internal.clone())
     }
 
-    /// Create a new table from data
+    /// Create a new table from an iterator of data
    ///
     /// # Parameters
     ///
@@ -451,6 +498,25 @@
         )
     }
 
+    /// Create a new table from a stream of data
+    ///
+    /// # Parameters
+    ///
+    /// * `name` - The name of the table
+    /// * `initial_data` - The initial data to write to the table
+    pub fn create_table_streaming<T: IntoArrowStream>(
+        &self,
+        name: impl Into<String>,
+        initial_data: T,
+    ) -> CreateTableBuilder<true, T> {
+        CreateTableBuilder::<true, T>::new_streaming(
+            self.internal.clone(),
+            name.into(),
+            initial_data,
+            self.embedding_registry.clone(),
+        )
+    }
+
     /// Create an empty table with a given schema
     ///
     /// # Parameters
@@ -788,12 +854,16 @@ mod test_utils {
 mod tests {
     use std::fs::create_dir_all;
 
+    use arrow::compute::concat_batches;
     use arrow_array::RecordBatchReader;
     use arrow_schema::{DataType, Field, Schema};
-    use futures::TryStreamExt;
+    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+    use futures::{stream, TryStreamExt};
+    use lance::error::{ArrowResult, DataFusionResult};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
     use tempfile::tempdir;
 
+    use crate::arrow::SimpleRecordBatchStream;
     use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
     use crate::query::QueryBase;
     use crate::query::{ExecutableQuery, QueryExecutionOptions};
@@ -976,6 +1046,63 @@
         assert_eq!(batches.len(), 1);
     }
 
+    #[tokio::test]
+    async fn test_create_table_streaming() {
+        let tmp_dir = tempdir().unwrap();
+
+        let uri = tmp_dir.path().to_str().unwrap();
+        let db = connect(uri).execute().await.unwrap();
+
+        let batches = make_data().collect::<ArrowResult<Vec<_>>>().unwrap();
+
+        let schema = batches.first().unwrap().schema();
+        let one_batch = concat_batches(&schema, batches.iter()).unwrap();
+
+        let ldb_stream = stream::iter(batches.clone().into_iter().map(Result::Ok));
+        let ldb_stream: SendableRecordBatchStream =
+            Box::pin(SimpleRecordBatchStream::new(ldb_stream, schema.clone()));
+
+        let tbl1 = db
+            .create_table_streaming("one", ldb_stream)
+            .execute()
+            .await
+            .unwrap();
+
+        let df_stream = stream::iter(batches.into_iter().map(DataFusionResult::Ok));
+        let df_stream: datafusion_physical_plan::SendableRecordBatchStream =
+            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), df_stream));
+
+        let tbl2 = db
+            .create_table_streaming("two", df_stream)
+            .execute()
+            .await
+            .unwrap();
+
+        let tbl1_data = tbl1
+            .query()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let tbl1_data = concat_batches(&schema, tbl1_data.iter()).unwrap();
+        assert_eq!(tbl1_data, one_batch);
+
+        let tbl2_data = tbl2
+            .query()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let tbl2_data = concat_batches(&schema, tbl2_data.iter()).unwrap();
+        assert_eq!(tbl2_data, one_batch);
+    }
+
     #[tokio::test]
     async fn drop_table() {
         let tmp_dir = tempdir().unwrap();
diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs
index c5e40544..63406059 100644
--- a/rust/lancedb/src/database.rs
+++ b/rust/lancedb/src/database.rs
@@ -18,8 +18,13 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow_array::RecordBatchReader;
+use async_trait::async_trait;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use futures::stream;
 use lance::dataset::ReadParams;
+use lance_datafusion::utils::StreamingWriteSource;
 
+use crate::arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt};
 use crate::error::Result;
 use crate::table::{BaseTable, TableDefinition, WriteOptions};
@@ -81,12 +86,41 @@ impl Default for CreateTableMode {
 
 /// The data to start a table or a schema to create an empty table
 pub enum CreateTableData {
-    /// Creates a table using data, no schema required as it will be obtained from the data
+    /// Creates a table using an iterator of data, the schema will be obtained from the data
     Data(Box<dyn RecordBatchReader + Send>),
+    /// Creates a table using a stream of data, the schema will be obtained from the data
+    StreamingData(SendableRecordBatchStream),
     /// Creates an empty table, the definition / schema must be provided separately
     Empty(TableDefinition),
 }
 
+impl CreateTableData {
+    pub fn schema(&self) -> Arc<arrow_schema::Schema> {
+        match self {
+            Self::Data(reader) => reader.schema(),
+            Self::StreamingData(stream) => stream.schema(),
+            Self::Empty(definition) => definition.schema.clone(),
+        }
+    }
+}
+
+#[async_trait]
+impl StreamingWriteSource for CreateTableData {
+    fn arrow_schema(&self) -> Arc<arrow_schema::Schema> {
+        self.schema()
+    }
+    fn into_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream {
+        match self {
+            Self::Data(reader) => reader.into_stream(),
+            Self::StreamingData(stream) => stream.into_df_stream(),
+            Self::Empty(table_definition) => {
+                let schema = table_definition.schema.clone();
+                Box::pin(RecordBatchStreamAdapter::new(schema, stream::empty()))
+            }
+        }
+    }
+}
+
 /// A request to create a table
 pub struct CreateTableRequest {
     /// The name of the new table
diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs
index ff38cbe0..fa711ca9 100644
--- a/rust/lancedb/src/database/listing.rs
+++ b/rust/lancedb/src/database/listing.rs
@@ -7,9 +7,9 @@
 use std::fs::create_dir_all;
 use std::path::Path;
 use std::{collections::HashMap, sync::Arc};
-use arrow_array::RecordBatchIterator;
 use lance::dataset::{ReadParams, WriteMode};
 use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore};
+use lance_datafusion::utils::StreamingWriteSource;
 use lance_encoding::version::LanceFileVersion;
 use lance_table::io::commit::commit_handler_from_url;
 use object_store::local::LocalFileSystem;
@@ -22,8 +22,8 @@
 use crate::table::NativeTable;
 use crate::utils::validate_table_name;
 use super::{
-    BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
-    OpenTableRequest, TableNamesRequest,
+    BaseTable, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, OpenTableRequest,
+    TableNamesRequest,
 };
 
 /// File extension to indicate a lance table
@@ -401,19 +401,12 @@ impl Database for ListingDatabase {
             write_params.mode = WriteMode::Overwrite;
         }
 
-        let data = match request.data {
-            CreateTableData::Data(data) => data,
-            CreateTableData::Empty(table_definition) => {
-                let schema = table_definition.schema.clone();
-                Box::new(RecordBatchIterator::new(vec![], schema))
-            }
-        };
-        let data_schema = data.schema();
+        let data_schema = request.data.arrow_schema();
 
         match NativeTable::create(
             &table_uri,
             &request.name,
-            data,
+            request.data,
             self.store_wrapper.clone(),
             Some(write_params),
             self.read_consistency_interval,
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index 7d4d8da0..951c91be 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -164,6 +164,11 @@ impl Database for RemoteDatabase {
     async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let data = match request.data {
             CreateTableData::Data(data) => data,
+            CreateTableData::StreamingData(_) => {
+                return Err(Error::NotSupported {
+                    message: "Creating a remote table from a streaming source".to_string(),
+                })
+            }
             CreateTableData::Empty(table_definition) => {
                 let schema = table_definition.schema.clone();
                 Box::new(RecordBatchIterator::new(vec![], schema))
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index 0747df90..0f766d36 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -28,13 +28,13 @@
 pub use lance::dataset::NewColumnTransform;
 pub use lance::dataset::ReadParams;
 pub use lance::dataset::Version;
 use lance::dataset::{
-    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode,
-    WriteParams,
+    InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::index::vector::utils::infer_vector_dim;
 use lance::io::WrappingObjectStore;
 use lance_datafusion::exec::execute_plan;
+use lance_datafusion::utils::StreamingWriteSource;
 use lance_index::vector::hnsw::builder::HnswBuildParams;
 use lance_index::vector::ivf::IvfBuildParams;
 use lance_index::vector::pq::PQBuildParams;
@@ -1264,7 +1264,7 @@ impl NativeTable {
     pub async fn create(
         uri: &str,
         name: &str,
-        batches: impl RecordBatchReader + Send + 'static,
+        batches: impl StreamingWriteSource,
         write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
         params: Option<WriteParams>,
         read_consistency_interval: Option<Duration>,
@@ -1279,7 +1279,9 @@ impl NativeTable {
             None => params,
         };
 
-        let dataset = Dataset::write(batches, uri, Some(params))
+        let insert_builder = InsertBuilder::new(uri).with_params(&params);
+        let dataset = insert_builder
+            .execute_stream(batches)
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -1287,6 +1289,7 @@ impl NativeTable {
                 },
                 source => Error::Lance { source },
             })?;
+
         Ok(Self {
             name: name.to_string(),
             uri: uri.to_string(),
@@ -2391,8 +2394,9 @@ mod tests {
     use arrow_data::ArrayDataBuilder;
     use arrow_schema::{DataType, Field, Schema, TimeUnit};
     use futures::TryStreamExt;
-    use lance::dataset::{Dataset, WriteMode};
+    use lance::dataset::WriteMode;
     use lance::io::{ObjectStoreParams, WrappingObjectStore};
+    use lance::Dataset;
     use rand::Rng;
     use tempfile::tempdir;
@@ -2442,6 +2446,7 @@ mod tests {
         let uri = tmp_dir.path().to_str().unwrap();
 
         let batches = make_test_batches();
+        let batches = Box::new(batches) as Box<dyn RecordBatchReader + Send>;
         let table = NativeTable::create(uri, "test", batches, None, None, None)
             .await
             .unwrap();
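Taken together, the new surface lets a caller pass an asynchronous source directly instead of first collecting it into a `RecordBatchReader`. A minimal end-to-end sketch, not part of the patch, assuming this branch of `lancedb` plus `arrow-array`, `arrow-schema`, `futures`, and `tokio` (with the `macros` and `rt-multi-thread` features) as dependencies; the table name and data are illustrative:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use futures::stream;
use lancedb::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use lancedb::connect;

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    let db = connect("/tmp/lancedb-streaming-demo").execute().await?;

    // Two small batches standing in for a real asynchronous source.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batches: Vec<RecordBatch> = (0..2)
        .map(|i| {
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![i, i + 1])) as ArrayRef],
            )
            .unwrap()
        })
        .collect();

    // Anything implementing IntoArrowStream works here, including a
    // datafusion_physical_plan::SendableRecordBatchStream.
    let source: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
        stream::iter(batches.into_iter().map(Ok::<_, lancedb::Error>)),
        schema,
    ));

    // The builder hands the stream to the writer; no intermediate
    // RecordBatchReader is constructed.
    let table = db.create_table_streaming("demo", source).execute().await?;
    println!("created table {}", table.name());
    Ok(())
}
```

Note that, per the `remote/db.rs` hunk above, remote (LanceDB Cloud) connections reject streaming sources with `Error::NotSupported`, and the embedding path in `into_request` still requires iterator input.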