diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index 15d78ac2..55e34a60 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -101,3 +101,21 @@ impl>> RecordBatchStream self.schema.clone() } } + +/// A trait for converting incoming data to Arrow +/// +/// Integrations should implement this trait to allow data to be +/// imported directly from the integration. For example, implementing +/// this trait for `Vec>` would allow the `Vec` to be directly +/// used in methods like [`crate::connection::Connection::create_table`] +/// or [`crate::table::Table::add`] +pub trait IntoArrow { + /// Convert the data into an Arrow array + fn into_arrow(self) -> Result>; +} + +impl IntoArrow for Box { + fn into_arrow(self) -> Result> { + Ok(self) + } +} diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 3c540b4f..54ae8d27 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -27,6 +27,7 @@ use object_store::{ }; use snafu::prelude::*; +use crate::arrow::IntoArrow; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; use crate::table::{NativeTable, WriteOptions}; @@ -116,23 +117,27 @@ impl TableNamesBuilder { } } +pub struct NoData {} + +impl IntoArrow for NoData { + fn into_arrow(self) -> Result> { + unreachable!("NoData should never be converted to Arrow") + } +} + /// A builder for configuring a [`Connection::create_table`] operation -pub struct CreateTableBuilder { +pub struct CreateTableBuilder { parent: Arc, pub(crate) name: String, - pub(crate) data: Option>, + pub(crate) data: Option, pub(crate) schema: Option, pub(crate) mode: CreateTableMode, pub(crate) write_options: WriteOptions, } // Builder methods that only apply when we have initial data -impl CreateTableBuilder { - fn new( - parent: Arc, - name: String, - data: Box, - ) -> Self { +impl CreateTableBuilder { + fn new(parent: Arc, name: String, data: T) -> Self { Self { parent, name, @@ -151,12 +156,32 @@ impl CreateTableBuilder { /// Execute the create table operation pub async fn execute(self) -> Result { - self.parent.clone().do_create_table(self).await + let parent = self.parent.clone(); + let (data, builder) = self.extract_data()?; + parent.do_create_table(builder, data).await + } + + fn extract_data( + mut self, + ) -> Result<( + Box, + CreateTableBuilder, + )> { + let data = self.data.take().unwrap().into_arrow()?; + let builder = CreateTableBuilder:: { + parent: self.parent, + name: self.name, + data: None, + schema: self.schema, + mode: self.mode, + write_options: self.write_options, + }; + Ok((data, builder)) } } // Builder methods that only apply when we do not have initial data -impl CreateTableBuilder { +impl CreateTableBuilder { fn new(parent: Arc, name: String, schema: SchemaRef) -> Self { Self { parent, @@ -174,7 +199,7 @@ impl CreateTableBuilder { } } -impl CreateTableBuilder { +impl CreateTableBuilder { /// Set the mode for creating the table /// /// This controls what happens if a table with the given name already exists @@ -237,17 +262,24 @@ pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + std::fmt::Display + 'static { async fn table_names(&self, options: TableNamesBuilder) -> Result>; - async fn do_create_table(&self, options: CreateTableBuilder) -> Result
; + async fn do_create_table( + &self, + options: CreateTableBuilder, + data: Box, + ) -> Result
; async fn do_open_table(&self, options: OpenTableBuilder) -> Result
; async fn drop_table(&self, name: &str) -> Result<()>; async fn drop_db(&self) -> Result<()>; - async fn do_create_empty_table(&self, options: CreateTableBuilder) -> Result
{ - let batches = RecordBatchIterator::new(vec![], options.schema.unwrap()); - let opts = CreateTableBuilder::::new(options.parent, options.name, Box::new(batches)) - .mode(options.mode) - .write_options(options.write_options); - self.do_create_table(opts).await + async fn do_create_empty_table( + &self, + options: CreateTableBuilder, + ) -> Result
{ + let batches = Box::new(RecordBatchIterator::new( + vec![], + options.schema.as_ref().unwrap().clone(), + )); + self.do_create_table(options, batches).await } } @@ -285,12 +317,12 @@ impl Connection { /// /// * `name` - The name of the table /// * `initial_data` - The initial data to write to the table - pub fn create_table( + pub fn create_table( &self, name: impl Into, - initial_data: Box, - ) -> CreateTableBuilder { - CreateTableBuilder::::new(self.internal.clone(), name.into(), initial_data) + initial_data: T, + ) -> CreateTableBuilder { + CreateTableBuilder::::new(self.internal.clone(), name.into(), initial_data) } /// Create an empty table with a given schema @@ -303,8 +335,8 @@ impl Connection { &self, name: impl Into, schema: SchemaRef, - ) -> CreateTableBuilder { - CreateTableBuilder::::new(self.internal.clone(), name.into(), schema) + ) -> CreateTableBuilder { + CreateTableBuilder::::new(self.internal.clone(), name.into(), schema) } /// Open an existing table in the database @@ -694,7 +726,11 @@ impl ConnectionInternal for Database { Ok(f) } - async fn do_create_table(&self, options: CreateTableBuilder) -> Result
{ + async fn do_create_table( + &self, + options: CreateTableBuilder, + data: Box, + ) -> Result
{ let table_uri = self.table_uri(&options.name)?; let mut write_params = options.write_options.lance_write_params.unwrap_or_default(); @@ -705,7 +741,7 @@ impl ConnectionInternal for Database { match NativeTable::create( &table_uri, &options.name, - options.data.unwrap(), + data, self.store_wrapper.clone(), Some(write_params), self.read_consistency_interval, diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index bf3f8668..a60dfb79 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -77,7 +77,7 @@ impl Select { /// may be installed. These models may accept something other than f32. For example, /// sentence transformers typically expect the query to be a string. This means that /// any kind of conversion library should expect to convert more than just f32. -pub trait ToQueryVector { +pub trait IntoQueryVector { /// Convert the user's query vector input to a query vector /// /// This trait exists to allow users to provide many different types as @@ -112,7 +112,7 @@ pub trait ToQueryVector { } // TODO: perhaps support some casts like f32->f64 and maybe even f64->f32? -impl ToQueryVector for Arc { +impl IntoQueryVector for Arc { fn to_query_vector( self, data_type: &DataType, @@ -147,7 +147,7 @@ impl ToQueryVector for Arc { } } -impl ToQueryVector for &dyn Array { +impl IntoQueryVector for &dyn Array { fn to_query_vector( self, data_type: &DataType, @@ -167,7 +167,7 @@ impl ToQueryVector for &dyn Array { } } -impl ToQueryVector for &[f16] { +impl IntoQueryVector for &[f16] { fn to_query_vector( self, data_type: &DataType, @@ -197,7 +197,7 @@ impl ToQueryVector for &[f16] { } } -impl ToQueryVector for &[f32] { +impl IntoQueryVector for &[f32] { fn to_query_vector( self, data_type: &DataType, @@ -227,7 +227,7 @@ impl ToQueryVector for &[f32] { } } -impl ToQueryVector for &[f64] { +impl IntoQueryVector for &[f64] { fn to_query_vector( self, data_type: &DataType, @@ -257,7 +257,7 @@ impl ToQueryVector for &[f64] { } } -impl ToQueryVector for &[f16; N] { +impl IntoQueryVector for &[f16; N] { fn to_query_vector( self, data_type: &DataType, @@ -268,7 +268,7 @@ impl ToQueryVector for &[f16; N] { } } -impl ToQueryVector for &[f32; N] { +impl IntoQueryVector for &[f32; N] { fn to_query_vector( self, data_type: &DataType, @@ -279,7 +279,7 @@ impl ToQueryVector for &[f32; N] { } } -impl ToQueryVector for &[f64; N] { +impl IntoQueryVector for &[f64; N] { fn to_query_vector( self, data_type: &DataType, @@ -290,7 +290,7 @@ impl ToQueryVector for &[f64; N] { } } -impl ToQueryVector for Vec { +impl IntoQueryVector for Vec { fn to_query_vector( self, data_type: &DataType, @@ -301,7 +301,7 @@ impl ToQueryVector for Vec { } } -impl ToQueryVector for Vec { +impl IntoQueryVector for Vec { fn to_query_vector( self, data_type: &DataType, @@ -312,7 +312,7 @@ impl ToQueryVector for Vec { } } -impl ToQueryVector for Vec { +impl IntoQueryVector for Vec { fn to_query_vector( self, data_type: &DataType, @@ -530,7 +530,7 @@ impl Query { /// # Arguments /// /// * `vector` - The vector that will be used for search. - pub fn nearest_to(self, vector: impl ToQueryVector) -> Result { + pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result { let mut vector_query = self.into_vector(); let query_vector = vector.to_query_vector(&DataType::Float32, "default")?; vector_query.query_vector = Some(query_vector); diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index fe430f3d..e45aca6a 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -14,13 +14,14 @@ use std::sync::Arc; +use arrow_array::RecordBatchReader; use async_trait::async_trait; use reqwest::header::CONTENT_TYPE; use serde::Deserialize; use tokio::task::spawn_blocking; use crate::connection::{ - ConnectionInternal, CreateTableBuilder, OpenTableBuilder, TableNamesBuilder, + ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder, }; use crate::error::Result; use crate::Table; @@ -74,8 +75,11 @@ impl ConnectionInternal for RemoteDatabase { Ok(rsp.json::().await?.tables) } - async fn do_create_table(&self, options: CreateTableBuilder) -> Result
{ - let data = options.data.unwrap(); + async fn do_create_table( + &self, + options: CreateTableBuilder, + data: Box, + ) -> Result
{ // TODO: https://github.com/lancedb/lancedb/issues/1026 // We should accept data from an async source. In the meantime, spawn this as blocking // to make sure we don't block the tokio runtime if the source is slow. diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 433ef08e..cbfa4fac 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform}; use crate::{ + connection::NoData, error::Result, index::{IndexBuilder, IndexConfig}, query::{Query, QueryExecutionOptions, VectorQuery}, @@ -63,7 +64,11 @@ impl TableInternal for RemoteTable { async fn count_rows(&self, _filter: Option) -> Result { todo!() } - async fn add(&self, _add: AddDataBuilder) -> Result<()> { + async fn add( + &self, + _add: AddDataBuilder, + _data: Box, + ) -> Result<()> { todo!() } async fn plain_query( diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a9eb1146..cd09f5ca 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -42,6 +42,8 @@ use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use log::info; use snafu::whatever; +use crate::arrow::IntoArrow; +use crate::connection::NoData; use crate::error::{Error, Result}; use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics}; use crate::index::IndexConfig; @@ -50,7 +52,7 @@ use crate::index::{ Index, IndexBuilder, }; use crate::query::{ - Query, QueryExecutionOptions, Select, ToQueryVector, VectorQuery, DEFAULT_TOP_K, + IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K, }; use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam}; @@ -124,14 +126,14 @@ pub enum AddDataMode { /// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`] /// operation -pub struct AddDataBuilder { +pub struct AddDataBuilder { parent: Arc, - pub(crate) data: Box, + pub(crate) data: T, pub(crate) mode: AddDataMode, pub(crate) write_options: WriteOptions, } -impl std::fmt::Debug for AddDataBuilder { +impl std::fmt::Debug for AddDataBuilder { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AddDataBuilder") .field("parent", &self.parent) @@ -141,7 +143,7 @@ impl std::fmt::Debug for AddDataBuilder { } } -impl AddDataBuilder { +impl AddDataBuilder { pub fn mode(mut self, mode: AddDataMode) -> Self { self.mode = mode; self @@ -153,7 +155,15 @@ impl AddDataBuilder { } pub async fn execute(self) -> Result<()> { - self.parent.clone().add(self).await + let parent = self.parent.clone(); + let data = self.data.into_arrow()?; + let without_data = AddDataBuilder:: { + data: NoData {}, + mode: self.mode, + parent: self.parent, + write_options: self.write_options, + }; + parent.add(without_data, data).await } } @@ -233,7 +243,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn async fn schema(&self) -> Result; /// Count the number of rows in this table. async fn count_rows(&self, filter: Option) -> Result; - async fn add(&self, add: AddDataBuilder) -> Result<()>; async fn plain_query( &self, query: &Query, @@ -244,6 +253,11 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn query: &VectorQuery, options: QueryExecutionOptions, ) -> Result; + async fn add( + &self, + add: AddDataBuilder, + data: Box, + ) -> Result<()>; async fn delete(&self, predicate: &str) -> Result<()>; async fn update(&self, update: UpdateBuilder) -> Result<()>; async fn create_index(&self, index: IndexBuilder) -> Result<()>; @@ -319,7 +333,7 @@ impl Table { /// /// * `batches` data to be added to the Table /// * `options` options to control how data is added - pub fn add(&self, batches: Box) -> AddDataBuilder { + pub fn add(&self, batches: T) -> AddDataBuilder { AddDataBuilder { parent: self.inner.clone(), data: batches, @@ -637,7 +651,7 @@ impl Table { /// This is a convenience method for preparing a vector query and /// is the same thing as calling `nearest_to` on the builder returned /// by `query`. See [`Query::nearest_to`] for more details. - pub fn vector_search(&self, query: impl ToQueryVector) -> Result { + pub fn vector_search(&self, query: impl IntoQueryVector) -> Result { self.query().nearest_to(query) } @@ -1288,7 +1302,11 @@ impl TableInternal for NativeTable { } } - async fn add(&self, add: AddDataBuilder) -> Result<()> { + async fn add( + &self, + add: AddDataBuilder, + data: Box, + ) -> Result<()> { let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams { mode: match add.mode { AddDataMode::Append => WriteMode::Append, @@ -1305,7 +1323,7 @@ impl TableInternal for NativeTable { self.dataset.ensure_mutable().await?; - let dataset = Dataset::write(add.data, &self.uri, Some(lance_params)).await?; + let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?; self.dataset.set_latest(dataset).await; Ok(()) }