diff --git a/nodejs/src/connection.rs b/nodejs/src/connection.rs
index 7f4a82d9..1fd3ca56 100644
--- a/nodejs/src/connection.rs
+++ b/nodejs/src/connection.rs
@@ -12,18 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
-
 use napi::bindgen_prelude::*;
 use napi_derive::*;
 
 use crate::table::Table;
-use vectordb::connection::{Connection as LanceDBConnection, Database};
+use vectordb::connection::Connection as LanceDBConnection;
 use vectordb::ipc::ipc_file_to_batches;
 
 #[napi]
 pub struct Connection {
-    conn: Arc<dyn LanceDBConnection>,
+    conn: LanceDBConnection,
 }
 
 #[napi]
@@ -32,9 +30,9 @@ impl Connection {
     #[napi(factory)]
    pub async fn new(uri: String) -> napi::Result<Self> {
        Ok(Self {
-            conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
+            conn: vectordb::connect(&uri).execute().await.map_err(|e| {
                napi::Error::from_reason(format!("Failed to connect to database: {}", e))
-            })?),
+            })?,
        })
    }
 
@@ -59,7 +57,8 @@ impl Connection {
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
        let tbl = self
            .conn
-            .create_table(&name, Box::new(batches), None)
+            .create_table(&name, Box::new(batches))
+            .execute()
            .await
            .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
        Ok(Table::new(tbl))
@@ -70,6 +69,7 @@ impl Connection {
        let tbl = self
            .conn
            .open_table(&name)
+            .execute()
            .await
            .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
        Ok(Table::new(tbl))
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index fdb16b26..a02522cf 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -15,6 +15,7 @@ use arrow_ipc::writer::FileWriter;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
 
+use vectordb::table::AddDataOptions;
 use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
 
 use crate::index::IndexBuilder;
@@ -48,12 +49,15 @@ impl Table {
    pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
        let batches = ipc_file_to_batches(buf.to_vec())
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
-        self.table.add(Box::new(batches), None).await.map_err(|e| {
-            napi::Error::from_reason(format!(
-                "Failed to add batches to table {}: {}",
-                self.table, e
-            ))
-        })
+        self.table
+            .add(Box::new(batches), AddDataOptions::default())
+            .await
+            .map_err(|e| {
+                napi::Error::from_reason(format!(
+                    "Failed to add batches to table {}: {}",
+                    self.table, e
+                ))
+            })
    }
 
    #[napi]
diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs
index 41212030..fcc21302 100644
--- a/rust/ffi/node/src/lib.rs
+++ b/rust/ffi/node/src/lib.rs
@@ -22,9 +22,9 @@ use object_store::CredentialProvider;
 use once_cell::sync::OnceCell;
 use tokio::runtime::Runtime;
 
-use vectordb::connection::Database;
+use vectordb::connect;
+use vectordb::connection::Connection;
 use vectordb::table::ReadParams;
-use vectordb::{ConnectOptions, Connection};
 
 use crate::error::ResultExt;
 use crate::query::JsQuery;
@@ -39,7 +39,7 @@ mod query;
 mod table;
 
 struct JsDatabase {
-    database: Arc<dyn Connection>,
+    database: Connection,
 }
 
 impl Finalize for JsDatabase {}
@@ -89,23 +89,23 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let channel = cx.channel();
    let (deferred, promise) = cx.promise();
 
-    let mut conn_options = ConnectOptions::new(&path);
+    let mut conn_builder = connect(&path);
    if let Some(region) = region {
-        conn_options = conn_options.region(&region);
+        conn_builder = conn_builder.region(&region);
    }
    if let Some(aws_creds) = aws_creds {
-        conn_options = conn_options.aws_creds(AwsCredential {
+        conn_builder = conn_builder.aws_creds(AwsCredential {
            key_id: aws_creds.key_id,
            secret_key: aws_creds.secret_key,
            token: aws_creds.token,
        });
    }
 
    rt.spawn(async move {
-        let database = Database::connect_with_options(&conn_options).await;
+        let database = conn_builder.execute().await;
 
        deferred.settle_with(&channel, move |mut cx| {
            let db = JsDatabase {
-                database: Arc::new(database.or_throw(&mut cx)?),
+                database: database.or_throw(&mut cx)?,
            };
            Ok(cx.boxed(db))
        });
@@ -217,7 +217,11 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let (deferred, promise) = cx.promise();
 
    rt.spawn(async move {
-        let table_rst = database.open_table_with_params(&table_name, params).await;
+        let table_rst = database
+            .open_table(&table_name)
+            .lance_read_params(params)
+            .execute()
+            .await;
 
        deferred.settle_with(&channel, move |mut cx| {
            let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs
index bb6bbfea..ee59186a 100644
--- a/rust/ffi/node/src/table.rs
+++ b/rust/ffi/node/src/table.rs
@@ -18,7 +18,7 @@ use arrow_array::{RecordBatch, RecordBatchIterator};
 use lance::dataset::optimize::CompactionOptions;
 use lance::dataset::{WriteMode, WriteParams};
 use lance::io::ObjectStoreParams;
-use vectordb::table::OptimizeAction;
+use vectordb::table::{AddDataOptions, OptimizeAction, WriteOptions};
 
 use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
 use neon::prelude::*;
@@ -80,7 +80,11 @@ impl JsTable {
        rt.spawn(async move {
            let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
            let table_rst = database
-                .create_table(&table_name, Box::new(batch_reader), Some(params))
+                .create_table(&table_name, Box::new(batch_reader))
+                .write_options(WriteOptions {
+                    lance_write_params: Some(params),
+                })
+                .execute()
                .await;
 
            deferred.settle_with(&channel, move |mut cx| {
@@ -121,7 +125,13 @@ impl JsTable {
        rt.spawn(async move {
            let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
-            let add_result = table.add(Box::new(batch_reader), Some(params)).await;
+            let opts = AddDataOptions {
+                write_options: WriteOptions {
+                    lance_write_params: Some(params),
+                },
+                ..Default::default()
+            };
+            let add_result = table.add(Box::new(batch_reader), opts).await;
 
            deferred.settle_with(&channel, move |mut cx| {
                add_result.or_throw(&mut cx)?;
diff --git a/rust/vectordb/examples/simple.rs b/rust/vectordb/examples/simple.rs
index 947c6952..c315ac00 100644
--- a/rust/vectordb/examples/simple.rs
+++ b/rust/vectordb/examples/simple.rs
@@ -19,7 +19,8 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
 use arrow_schema::{DataType, Field, Schema};
 use futures::TryStreamExt;
 
-use vectordb::Connection;
+use vectordb::connection::Connection;
+use vectordb::table::AddDataOptions;
 use vectordb::{connect, Result, Table, TableRef};
 
 #[tokio::main]
@@ -29,18 +30,18 @@ async fn main() -> Result<()> {
    }
    // --8<-- [start:connect]
    let uri = "data/sample-lancedb";
-    let db = connect(uri).await?;
+    let db = connect(uri).execute().await?;
    // --8<-- [end:connect]
 
    // --8<-- [start:list_names]
    println!("{:?}", db.table_names().await?);
    // --8<-- [end:list_names]
-    let tbl = create_table(db.clone()).await?;
+    let tbl = create_table(&db).await?;
    create_index(tbl.as_ref()).await?;
    let batches = search(tbl.as_ref()).await?;
    println!("{:?}", batches);
 
-    create_empty_table(db.clone()).await.unwrap();
+    create_empty_table(&db).await.unwrap();
 
    // --8<-- [start:delete]
    tbl.delete("id > 24").await.unwrap();
24").await.unwrap(); @@ -55,17 +56,14 @@ async fn main() -> Result<()> { #[allow(dead_code)] async fn open_with_existing_tbl() -> Result<()> { let uri = "data/sample-lancedb"; - let db = connect(uri).await?; + let db = connect(uri).execute().await?; // --8<-- [start:open_with_existing_file] - let _ = db - .open_table_with_params("my_table", Default::default()) - .await - .unwrap(); + let _ = db.open_table("my_table").execute().await.unwrap(); // --8<-- [end:open_with_existing_file] Ok(()) } -async fn create_table(db: Arc) -> Result { +async fn create_table(db: &Connection) -> Result { // --8<-- [start:create_table] const TOTAL: usize = 1000; const DIM: usize = 128; @@ -102,7 +100,8 @@ async fn create_table(db: Arc) -> Result { schema.clone(), ); let tbl = db - .create_table("my_table", Box::new(batches), None) + .create_table("my_table", Box::new(batches)) + .execute() .await .unwrap(); // --8<-- [end:create_table] @@ -126,21 +125,21 @@ async fn create_table(db: Arc) -> Result { schema.clone(), ); // --8<-- [start:add] - tbl.add(Box::new(new_batches), None).await.unwrap(); + tbl.add(Box::new(new_batches), AddDataOptions::default()) + .await + .unwrap(); // --8<-- [end:add] Ok(tbl) } -async fn create_empty_table(db: Arc) -> Result { +async fn create_empty_table(db: &Connection) -> Result { // --8<-- [start:create_empty_table] let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("item", DataType::Utf8, true), ])); - let batches = RecordBatchIterator::new(vec![], schema.clone()); - db.create_table("empty_table", Box::new(batches), None) - .await + db.create_empty_table("empty_table", schema).execute().await // --8<-- [end:create_empty_table] } diff --git a/rust/vectordb/src/connection.rs b/rust/vectordb/src/connection.rs index cb2800c5..e6608ae2 100644 --- a/rust/vectordb/src/connection.rs +++ b/rust/vectordb/src/connection.rs @@ -13,14 +13,14 @@ // limitations under the License. //! LanceDB Database -//! use std::fs::create_dir_all; use std::path::Path; use std::sync::Arc; -use arrow_array::RecordBatchReader; -use lance::dataset::WriteParams; +use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_schema::SchemaRef; +use lance::dataset::{ReadParams, WriteMode}; use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore}; use object_store::{ aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider, @@ -29,73 +29,283 @@ use snafu::prelude::*; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; -use crate::table::{NativeTable, ReadParams, TableRef}; +use crate::table::{NativeTable, TableRef, WriteOptions}; pub const LANCE_FILE_EXTENSION: &str = "lance"; -/// A connection to LanceDB -#[async_trait::async_trait] -pub trait Connection: Send + Sync { - /// Get the names of all tables in the database. - async fn table_names(&self) -> Result>; +pub type TableBuilderCallback = Box OpenTableBuilder + Send>; - /// Create a new table in the database. +/// Describes what happens when creating a table and a table with +/// the same name already exists +pub enum CreateTableMode { + /// If the table already exists, an error is returned + Create, + /// If the table already exists, it is opened. Any provided data is + /// ignored. 
+    /// how the table is opened
+    ExistOk(TableBuilderCallback),
+    /// If the table already exists, it is overwritten
+    Overwrite,
+}
+
+impl CreateTableMode {
+    pub fn exist_ok(
+        callback: impl FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send + 'static,
+    ) -> Self {
+        Self::ExistOk(Box::new(callback))
+    }
+}
+
+impl Default for CreateTableMode {
+    fn default() -> Self {
+        Self::Create
+    }
+}
+
+/// Describes what happens when a vector either contains NaN or
+/// does not have enough values
+#[derive(Clone, Debug, Default)]
+enum BadVectorHandling {
+    /// An error is returned
+    #[default]
+    Error,
+    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
+    /// The offending row is dropped
+    Drop,
+    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
+    /// The invalid/missing items are replaced by fill_value
+    Fill(f32),
+}
+
+/// A builder for configuring a [`Connection::create_table`] operation
+pub struct CreateTableBuilder<const HAS_DATA: bool> {
+    parent: Arc<dyn ConnectionInternal>,
+    name: String,
+    data: Option<Box<dyn RecordBatchReader + Send>>,
+    schema: Option<SchemaRef>,
+    mode: CreateTableMode,
+    write_options: WriteOptions,
+}
+
+// Builder methods that only apply when we have initial data
+impl CreateTableBuilder<true> {
+    fn new(
+        parent: Arc<dyn ConnectionInternal>,
+        name: String,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Self {
+        Self {
+            parent,
+            name,
+            data: Some(data),
+            schema: None,
+            mode: CreateTableMode::default(),
+            write_options: WriteOptions::default(),
+        }
+    }
+
+    /// Apply the given write options when writing the initial data
+    pub fn write_options(mut self, write_options: WriteOptions) -> Self {
+        self.write_options = write_options;
+        self
+    }
+
+    /// Execute the create table operation
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_create_table(self).await
+    }
+}
+
+// Builder methods that only apply when we do not have initial data
+impl CreateTableBuilder<false> {
+    fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
+        Self {
+            parent,
+            name,
+            data: None,
+            schema: Some(schema),
+            mode: CreateTableMode::default(),
+            write_options: WriteOptions::default(),
+        }
+    }
+
+    /// Execute the create table operation
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_create_empty_table(self).await
+    }
+}
+
+impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
+    /// Set the mode for creating the table
+    ///
+    /// This controls what happens if a table with the given name already exists
+    pub fn mode(mut self, mode: CreateTableMode) -> Self {
+        self.mode = mode;
+        self
+    }
+}
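+
+// Example (sketch): `db` is an open `Connection` and `data` is a
+// `Box<dyn RecordBatchReader + Send>`; tolerate an existing table of the
+// same name and customize how it is opened in that case:
+//
+//     db.create_table("my_table", data)
+//         .mode(CreateTableMode::exist_ok(|builder| builder.index_cache_size(16)))
+//         .execute()
+//         .await?;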
+
+#[derive(Clone, Debug)]
+pub struct OpenTableBuilder {
+    parent: Arc<dyn ConnectionInternal>,
+    name: String,
+    index_cache_size: u32,
+    lance_read_params: Option<ReadParams>,
+}
+
+impl OpenTableBuilder {
+    fn new(parent: Arc<dyn ConnectionInternal>, name: String) -> Self {
+        Self {
+            parent,
+            name,
+            index_cache_size: 256,
+            lance_read_params: None,
+        }
+    }
+
+    /// Set the size of the index cache, specified as a number of entries
+    ///
+    /// The default value is 256
+    ///
+    /// The exact meaning of an "entry" will depend on the type of index:
+    /// * IVF - there is one entry for each IVF partition
+    /// * BTREE - there is one entry for the entire index
+    ///
+    /// This cache applies to the entire opened table, across all indices.
+    /// Setting this value higher will increase performance on larger datasets
+    /// at the expense of more RAM
+    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
+        self.index_cache_size = index_cache_size;
+        self
+    }
+
+    /// Advanced parameters that can be used to customize table reads
+    ///
+    /// If set, these will take precedence over any overlapping `OpenTableBuilder` options
+    pub fn lance_read_params(mut self, params: ReadParams) -> Self {
+        self.lance_read_params = Some(params);
+        self
+    }
+
+    /// Open the table
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_open_table(self).await
+    }
+}
+
+#[async_trait::async_trait]
+trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
+    async fn table_names(&self) -> Result<Vec<String>>;
+    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
+    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
+    async fn drop_table(&self, name: &str) -> Result<()>;
+    async fn drop_db(&self) -> Result<()>;
+
+    async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<TableRef> {
+        let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
+        let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
+            .mode(options.mode)
+            .write_options(options.write_options);
+        self.do_create_table(opts).await
+    }
+}
+
+/// A connection to LanceDB
+#[derive(Clone)]
+pub struct Connection {
+    uri: String,
+    internal: Arc<dyn ConnectionInternal>,
+}
+
+impl Connection {
+    /// Get the URI of the connection
+    pub fn uri(&self) -> &str {
+        self.uri.as_str()
+    }
+
+    /// Get the names of all tables in the database.
+    pub async fn table_names(&self) -> Result<Vec<String>> {
+        self.internal.table_names().await
+    }
+
+    /// Create a new table from data
    ///
    /// # Parameters
    ///
-    /// * `name` - The name of the table.
-    /// * `batches` - The initial data to write to the table.
-    /// * `params` - Optional [`WriteParams`] to create the table.
-    ///
-    /// # Returns
-    /// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
-    async fn create_table(
+    /// * `name` - The name of the table
+    /// * `initial_data` - The initial data to write to the table
+    pub fn create_table(
        &self,
-        name: &str,
-        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
-    ) -> Result<TableRef>;
-
-    async fn open_table(&self, name: &str) -> Result<TableRef> {
-        self.open_table_with_params(name, ReadParams::default())
-            .await
+        name: impl Into<String>,
+        initial_data: Box<dyn RecordBatchReader + Send>,
+    ) -> CreateTableBuilder<true> {
+        CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
    }
 
-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
+    /// Create an empty table with a given schema
+    ///
+    /// # Parameters
+    ///
+    /// * `name` - The name of the table
+    /// * `schema` - The schema of the table
+    pub fn create_empty_table(
+        &self,
+        name: impl Into<String>,
+        schema: SchemaRef,
+    ) -> CreateTableBuilder<false> {
+        CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
+    }
+
+    /// Open an existing table in the database
+    ///
+    /// # Arguments
+    /// * `name` - The name of the table
+    ///
+    /// # Returns
+    /// Created [`TableRef`], or [`Error::TableNotFound`] if the table does not exist.
+    pub fn open_table(&self, name: impl Into<String>) -> OpenTableBuilder {
+        OpenTableBuilder::new(self.internal.clone(), name.into())
+    }
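+
+    // Example (sketch): open a table with a larger index cache:
+    //
+    //     let table = db.open_table("my_table").index_cache_size(1024).execute().await?;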
 
    /// Drop a table in the database.
    ///
    /// # Arguments
-    /// * `name` - The name of the table.
-    async fn drop_table(&self, name: &str) -> Result<()>;
+    /// * `name` - The name of the table to drop
+    pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
+        self.internal.drop_table(name.as_ref()).await
+    }
+
+    /// Drop the database
+    ///
+    /// This is the same as dropping all of the tables
+    pub async fn drop_db(&self) -> Result<()> {
+        self.internal.drop_db().await
+    }
 }
 
 #[derive(Debug)]
-pub struct ConnectOptions {
+pub struct ConnectBuilder {
    /// Database URI
    ///
-    /// # Accpeted URI formats
+    /// ### Accepted URI formats
    ///
    /// - `/path/to/database` - local database on file system.
    /// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
-    /// - `db://dbname` - Lance Cloud
-    pub uri: String,
+    /// - `db://dbname` - LanceDB Cloud
+    uri: String,
 
-    /// Lance Cloud API key
-    pub api_key: Option<String>,
-    /// Lance Cloud region
-    pub region: Option<String>,
-    /// Lance Cloud host override
-    pub host_override: Option<String>,
+    /// LanceDB Cloud API key, required if using LanceDB Cloud
+    api_key: Option<String>,
+    /// LanceDB Cloud region, required if using LanceDB Cloud
+    region: Option<String>,
+    /// LanceDB Cloud host override, only required if using an on-premises LanceDB Cloud instance
+    host_override: Option<String>,
 
    /// User provided AWS credentials
-    pub aws_creds: Option<AwsCredential>,
-
-    /// The maximum number of indices to cache in memory. Defaults to 256.
-    pub index_cache_size: u32,
+    aws_creds: Option<AwsCredential>,
 }
 
-impl ConnectOptions {
+impl ConnectBuilder {
    /// Create a new [`ConnectOptions`] with the given database URI.
    pub fn new(uri: &str) -> Self {
@@ -104,7 +314,6 @@ impl ConnectOptions {
            region: None,
            host_override: None,
            aws_creds: None,
-            index_cache_size: 256,
        }
    }
 
@@ -124,15 +333,18 @@ impl ConnectOptions {
    }
 
    /// [`AwsCredential`] to use when connecting to S3.
-    ///
    pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
        self.aws_creds = Some(aws_creds);
        self
    }
 
-    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
-        self.index_cache_size = index_cache_size;
-        self
+    /// Establishes a connection to the database
+    pub async fn execute(self) -> Result<Connection> {
+        let internal = Arc::new(Database::connect_with_options(&self).await?);
+        Ok(Connection {
+            internal,
+            uri: self.uri,
+        })
    }
 }
 
@@ -140,29 +352,14 @@ impl ConnectOptions {
 ///
 /// # Arguments
 ///
-/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
-///
-/// ## Accepted URI formats
-///
-/// - `/path/to/database` - local database on file system.
-/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
-/// - `db://dbname` - Lance Cloud
-///
-pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
-    let options = ConnectOptions::new(uri);
-    connect_with_options(&options).await
+/// * `uri` - URI where the database is located, can be a local directory, supported remote cloud storage,
+///   or a LanceDB Cloud database. See [ConnectBuilder::uri] for a list of accepted formats
+pub fn connect(uri: &str) -> ConnectBuilder {
+    ConnectBuilder::new(uri)
 }
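+
+// Example (sketch; the path is illustrative):
+//
+//     let db = connect("data/sample-lancedb").execute().await?;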
 
-/// Connect with [`ConnectOptions`].
-///
-/// # Arguments
-/// - `options` - [`ConnectOptions`] to connect to the database.
-pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
-    let db = Database::connect(&options.uri).await?;
-    Ok(Arc::new(db))
-}
-
-pub struct Database {
+#[derive(Debug)]
+struct Database {
    object_store: ObjectStore,
 
    query_string: Option<String>,
@@ -179,21 +376,7 @@ const MIRRORED_STORE: &str = "mirroredStore";
 /// A connection to LanceDB
 impl Database {
-    /// Connects to LanceDB
-    ///
-    /// # Arguments
-    ///
-    /// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
-    ///
-    /// # Returns
-    ///
-    /// * A [Database] object.
-    pub async fn connect(uri: &str) -> Result<Database> {
-        let options = ConnectOptions::new(uri);
-        Self::connect_with_options(&options).await
-    }
-
-    pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
+    async fn connect_with_options(options: &ConnectBuilder) -> Result<Database> {
        let uri = &options.uri;
 
        let parse_res = url::Url::parse(uri);
@@ -333,7 +516,7 @@ impl Database {
 }
 
 #[async_trait::async_trait]
-impl Connection for Database {
+impl ConnectionInternal for Database {
    async fn table_names(&self) -> Result<Vec<String>> {
        let mut f = self
            .object_store
@@ -354,40 +537,47 @@
        Ok(f)
    }
 
-    async fn create_table(
-        &self,
-        name: &str,
-        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
-    ) -> Result<TableRef> {
-        let table_uri = self.table_uri(name)?;
+    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
+        let table_uri = self.table_uri(&options.name)?;
 
-        Ok(Arc::new(
-            NativeTable::create(
-                &table_uri,
-                name,
-                batches,
-                self.store_wrapper.clone(),
-                params,
-            )
-            .await?,
-        ))
+        let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
+        if matches!(&options.mode, CreateTableMode::Overwrite) {
+            write_params.mode = WriteMode::Overwrite;
+        }
+
+        match NativeTable::create(
+            &table_uri,
+            &options.name,
+            options.data.unwrap(),
+            self.store_wrapper.clone(),
+            Some(write_params),
+        )
+        .await
+        {
+            Ok(table) => Ok(Arc::new(table)),
+            Err(Error::TableAlreadyExists { name }) => match options.mode {
+                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
+                CreateTableMode::ExistOk(callback) => {
+                    let builder = OpenTableBuilder::new(options.parent, options.name);
+                    let builder = (callback)(builder);
+                    builder.execute().await
+                }
+                CreateTableMode::Overwrite => unreachable!(),
+            },
+            Err(err) => Err(err),
+        }
    }
 
-    /// Open a table in the database.
-    ///
-    /// # Arguments
-    /// * `name` - The name of the table.
-    /// * `params` - The parameters to open the table.
-    ///
-    /// # Returns
-    ///
-    /// * A [TableRef] object.
-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
-        let table_uri = self.table_uri(name)?;
+    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef> {
+        let table_uri = self.table_uri(&options.name)?;
 
        Ok(Arc::new(
-            NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
-                .await?,
+            NativeTable::open_with_params(
+                &table_uri,
+                &options.name,
+                self.store_wrapper.clone(),
+                options.lance_read_params,
+            )
+            .await?,
        ))
    }
 
@@ -397,12 +587,17 @@ impl Connection for Database {
        self.object_store.remove_dir_all(full_path).await?;
        Ok(())
    }
+
+    async fn drop_db(&self) -> Result<()> {
+        todo!()
+    }
 }
 
 #[cfg(test)]
 mod tests {
    use std::fs::create_dir_all;
 
+    use arrow_schema::{DataType, Field, Schema};
    use tempfile::tempdir;
 
    use super::*;
@@ -411,7 +606,7 @@
    async fn test_connect() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
 
        assert_eq!(db.uri, uri);
    }
@@ -429,7 +624,8 @@
        let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
        let relative_uri = relative_root.join(&uri);
 
-        let db = Database::connect(relative_uri.to_str().unwrap())
+        let db = connect(relative_uri.to_str().unwrap())
+            .execute()
            .await
            .unwrap();
 
@@ -444,7 +640,7 @@
        create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
 
        let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
        let tables = db.table_names().await.unwrap();
        assert_eq!(tables.len(), 2);
        assert!(tables[0].eq(&String::from("table1")));
@@ -462,10 +658,44 @@
        create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
 
        let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
        db.drop_table("table1").await.unwrap();
 
        let tables = db.table_names().await.unwrap();
        assert_eq!(tables.len(), 0);
    }
+
+    #[tokio::test]
+    async fn test_create_table_already_exists() {
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+        let db = connect(uri).execute().await.unwrap();
+        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
+        db.create_empty_table("test", schema.clone())
+            .execute()
+            .await
+            .unwrap();
+        // TODO: None of the open table options are "inspectable" right now but once one is we
+        // should assert we are passing these options in correctly
+        db.create_empty_table("test", schema)
+            .mode(CreateTableMode::exist_ok(|builder| {
+                builder.index_cache_size(16)
+            }))
+            .execute()
+            .await
+            .unwrap();
+        let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
+        assert!(db
+            .create_empty_table("test", other_schema.clone())
+            .execute()
+            .await
+            .is_err());
+        let overwritten = db
+            .create_empty_table("test", other_schema.clone())
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        assert_eq!(other_schema, overwritten.schema());
+    }
 }
diff --git a/rust/vectordb/src/data/sanitize.rs b/rust/vectordb/src/data/sanitize.rs
index c5efd2bc..fe139b99 100644
--- a/rust/vectordb/src/data/sanitize.rs
+++ b/rust/vectordb/src/data/sanitize.rs
@@ -174,7 +174,6 @@ fn coerce_schema_batch(
 }
 
 /// Coerce the reader (input data) to match the given [Schema].
-///
 pub fn coerce_schema(
    reader: impl RecordBatchReader + Send + 'static,
    schema: Arc<Schema>,
diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs
index 22b7d518..e7dc3d78 100644
--- a/rust/vectordb/src/io/object_store.rs
+++ b/rust/vectordb/src/io/object_store.rs
@@ -342,7 +342,7 @@ mod test {
    use object_store::local::LocalFileSystem;
    use tempfile;
 
-    use crate::connection::{Connection, Database};
+    use crate::{connect, table::WriteOptions};
 
    #[tokio::test]
    async fn test_e2e() {
@@ -354,7 +354,7 @@ mod test {
            secondary: Arc::new(secondary_store),
        });
 
-        let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
+        let db = connect(dir1.to_str().unwrap()).execute().await.unwrap();
 
        let mut param = WriteParams::default();
        let store_params = ObjectStoreParams {
@@ -368,7 +368,11 @@ mod test {
        datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
 
        let res = db
-            .create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
+            .create_table("test", Box::new(datagen.batch(100)))
+            .write_options(WriteOptions {
+                lance_write_params: Some(param),
+            })
+            .execute()
            .await;
 
        // leave this here for easy debugging
diff --git a/rust/vectordb/src/lib.rs b/rust/vectordb/src/lib.rs
index fc4ac149..c2ed48bc 100644
--- a/rust/vectordb/src/lib.rs
+++ b/rust/vectordb/src/lib.rs
@@ -43,10 +43,9 @@
 //! #### Connect to a database.
 //!
 //! ```rust
-//! use vectordb::connect;
 //! # use arrow_schema::{Field, Schema};
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
-//! let db = connect("data/sample-lancedb").await.unwrap();
+//! let db = vectordb::connect("data/sample-lancedb").execute().await.unwrap();
 //! # });
 //! ```
 //!
@@ -56,14 +55,20 @@
 //! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
 //! - `db://dbname` - Lance Cloud
 //!
-//! You can also use [`ConnectOptions`] to configure the connectoin to the database.
+//! You can also use [`ConnectBuilder`] to configure the connection to the database.
 //!
 //! ```rust
-//! use vectordb::{connect_with_options, ConnectOptions};
+//! use object_store::aws::AwsCredential;
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
-//! let options = ConnectOptions::new("data/sample-lancedb")
-//!     .index_cache_size(1024);
-//! let db = connect_with_options(&options).await.unwrap();
+//! let db = vectordb::connect("data/sample-lancedb")
+//!     .aws_creds(AwsCredential {
+//!         key_id: "some_key".to_string(),
+//!         secret_key: "some_secret".to_string(),
+//!         token: None,
+//!     })
+//!     .execute()
+//!     .await
+//!     .unwrap();
 //! # });
 //! ```
 //!
@@ -79,31 +84,44 @@
 //!
 //! ```rust
 //! # use std::sync::Arc;
-//! use arrow_schema::{DataType, Schema, Field};
 //! use arrow_array::{RecordBatch, RecordBatchIterator};
+//! use arrow_schema::{DataType, Field, Schema};
 //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
-//! # use vectordb::connection::{Database, Connection};
-//! # use vectordb::connect;
 //!
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! # let tmpdir = tempfile::tempdir().unwrap();
-//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
 //! let schema = Arc::new(Schema::new(vec![
-//!     Field::new("id", DataType::Int32, false),
-//!     Field::new("vector", DataType::FixedSizeList(
-//!         Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
+//!     Field::new("id", DataType::Int32, false),
Field::new("id", DataType::Int32, false), +//! Field::new( +//! "vector", +//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128), +//! true, +//! ), //! ])); //! // Create a RecordBatch stream. -//! let batches = RecordBatchIterator::new(vec![ -//! RecordBatch::try_new(schema.clone(), +//! let batches = RecordBatchIterator::new( +//! vec![RecordBatch::try_new( +//! schema.clone(), //! vec![ -//! Arc::new(Int32Array::from_iter_values(0..1000)), -//! Arc::new(FixedSizeListArray::from_iter_primitive::( -//! (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)), -//! ]).unwrap() -//! ].into_iter().map(Ok), -//! schema.clone()); -//! db.create_table("my_table", Box::new(batches), None).await.unwrap(); +//! Arc::new(Int32Array::from_iter_values(0..256)), +//! Arc::new( +//! FixedSizeListArray::from_iter_primitive::( +//! (0..256).map(|_| Some(vec![Some(1.0); 128])), +//! 128, +//! ), +//! ), +//! ], +//! ) +//! .unwrap()] +//! .into_iter() +//! .map(Ok), +//! schema.clone(), +//! ); +//! db.create_table("my_table", Box::new(batches)) +//! .execute() +//! .await +//! .unwrap(); //! # }); //! ``` //! @@ -111,14 +129,13 @@ //! //! ```no_run //! # use std::sync::Arc; -//! # use vectordb::connect; //! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch, //! # RecordBatchIterator, Int32Array}; //! # use arrow_schema::{Schema, Field, DataType}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); -//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap(); -//! # let tbl = db.open_table("idx_test").await.unwrap(); +//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); +//! # let tbl = db.open_table("idx_test").execute().await.unwrap(); //! tbl.create_index(&["vector"]) //! .ivf_pq() //! .num_partitions(256) @@ -136,10 +153,9 @@ //! # use arrow_schema::{DataType, Schema, Field}; //! # use arrow_array::{RecordBatch, RecordBatchIterator}; //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; -//! # use vectordb::connection::{Database, Connection}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); -//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap(); +//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap(); //! # let schema = Arc::new(Schema::new(vec![ //! # Field::new("id", DataType::Int32, false), //! # Field::new("vector", DataType::FixedSizeList( @@ -154,8 +170,8 @@ //! # ]).unwrap() //! # ].into_iter().map(Ok), //! # schema.clone()); -//! # db.create_table("my_table", Box::new(batches), None).await.unwrap(); -//! # let table = db.open_table("my_table").await.unwrap(); +//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap(); +//! # let table = db.open_table("my_table").execute().await.unwrap(); //! let results = table //! .search(&[1.0; 128]) //! .execute_stream() @@ -165,8 +181,6 @@ //! .await //! .unwrap(); //! # }); -//! -//! //! 
 
 pub mod connection;
@@ -179,10 +193,8 @@
 pub mod query;
 pub mod table;
 pub mod utils;
 
-pub use connection::{Connection, Database};
 pub use error::{Error, Result};
 pub use table::{Table, TableRef};
 
 /// Connect to a database
-pub use connection::{connect, connect_with_options, ConnectOptions};
-pub use lance::dataset::WriteMode;
+pub use connection::connect;
diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs
index 53765ab5..57086359 100644
--- a/rust/vectordb/src/query.rs
+++ b/rust/vectordb/src/query.rs
@@ -60,7 +60,6 @@ impl Query {
    /// # Arguments
    ///
    /// * `dataset` - Lance dataset.
-    ///
    pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
        Self {
            dataset,
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index 7c5ac300..c6866899 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -17,7 +17,7 @@ use std::path::Path;
 use std::sync::{Arc, Mutex};
 
-use arrow_array::RecordBatchReader;
+use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{Schema, SchemaRef};
 use async_trait::async_trait;
 use chrono::Duration;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
    compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
 pub use lance::dataset::ReadParams;
-use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
+use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -38,7 +38,6 @@ use crate::index::vector::{VectorIndex, VectorIndexStatistics};
 use crate::index::IndexBuilder;
 use crate::query::Query;
 use crate::utils::{PatchReadParam, PatchWriteParam};
-use crate::WriteMode;
 
 use self::merge::{MergeInsert, MergeInsertBuilder};
 
@@ -85,6 +84,35 @@ pub struct OptimizeStats {
    pub prune: Option<RemovalStats>,
 }
 
+/// Options to use when writing data
+#[derive(Clone, Debug, Default)]
+pub struct WriteOptions {
+    // Coming soon: https://github.com/lancedb/lancedb/issues/992
+    // /// What behavior to take if the data contains invalid vectors
+    // pub on_bad_vectors: BadVectorHandling,
+    /// Advanced parameters that can be used to customize table creation
+    ///
+    /// If set, these will take precedence over any overlapping options (e.g. [`AddDataMode`] when adding data)
+    pub lance_write_params: Option<WriteParams>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub enum AddDataMode {
+    /// Rows will be appended to the table (the default)
+    #[default]
+    Append,
+    /// The existing table will be overwritten with the new data
+    Overwrite,
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct AddDataOptions {
+    /// Whether to add new rows (the default) or replace the existing data
+    pub mode: AddDataMode,
+    /// Options to use when writing the data
+    pub write_options: WriteOptions,
+}
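+
+// Example (sketch): replace a table's contents instead of appending;
+// `batches` is a `Box<dyn RecordBatchReader + Send>`:
+//
+//     table
+//         .add(batches, AddDataOptions {
+//             mode: AddDataMode::Overwrite,
+//             ..Default::default()
+//         })
+//         .await?;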
+
 /// A Table is a collection of strong typed Rows.
 ///
 /// The type of the each row is defined in Apache Arrow [Schema].
@@ -112,12 +140,12 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// # Arguments
    ///
-    /// * `batches` RecordBatch to be saved in the Table
-    /// * `params` Append / Overwrite existing records. Default: Append
+    /// * `batches` data to be added to the Table
+    /// * `options` options to control how data is added
    async fn add(
        &self,
        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
+        options: AddDataOptions,
    ) -> Result<()>;
 
    /// Delete the rows from table that match the predicate.
@@ -129,28 +157,43 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// #     RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
    /// # let schema = Arc::new(Schema::new(vec![
    /// #     Field::new("id", DataType::Int32, false),
    /// #     Field::new("vector", DataType::FixedSizeList(
    /// #         Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
    /// # ]));
-    /// let batches = RecordBatchIterator::new(vec![
-    ///     RecordBatch::try_new(schema.clone(),
-    ///     vec![
-    ///         Arc::new(Int32Array::from_iter_values(0..10)),
-    ///         Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-    ///             (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
-    ///     ]).unwrap()
-    /// ].into_iter().map(Ok),
-    /// schema.clone());
-    /// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
+    /// let batches = RecordBatchIterator::new(
+    ///     vec![RecordBatch::try_new(
+    ///         schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(
+    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
+    ///                     128,
+    ///                 ),
+    ///             ),
+    ///         ],
+    ///     )
+    ///     .unwrap()]
+    ///     .into_iter()
+    ///     .map(Ok),
+    ///     schema.clone(),
+    /// );
+    /// let tbl = db
+    ///     .create_table("delete_test", Box::new(batches))
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
    /// tbl.delete("id > 5").await.unwrap();
    /// # });
    /// ```
@@ -162,14 +205,16 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// #     RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
-    /// # let tbl = db.open_table("idx_test").await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
+    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
    /// tbl.create_index(&["vector"])
    ///     .ivf_pq()
    ///     .num_partitions(256)
@@ -214,32 +259,44 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// #     RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
-    /// # let tbl = db.open_table("idx_test").await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
+    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
    /// # let schema = Arc::new(Schema::new(vec![
    /// #     Field::new("id", DataType::Int32, false),
    /// #     Field::new("vector", DataType::FixedSizeList(
    /// #         Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
    /// # ]));
-    /// let new_data = RecordBatchIterator::new(vec![
-    ///     RecordBatch::try_new(schema.clone(),
-    ///     vec![
-    ///         Arc::new(Int32Array::from_iter_values(0..10)),
-    ///         Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-    ///             (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
-    ///     ]).unwrap()
-    /// ].into_iter().map(Ok),
-    /// schema.clone());
+    /// let new_data = RecordBatchIterator::new(
+    ///     vec![RecordBatch::try_new(
+    ///         schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(
+    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
+    ///                     128,
+    ///                 ),
+    ///             ),
+    ///         ],
+    ///     )
+    ///     .unwrap()]
+    ///     .into_iter()
+    ///     .map(Ok),
+    ///     schema.clone(),
+    /// );
    /// // Perform an upsert operation
    /// let mut merge_insert = tbl.merge_insert(&["id"]);
-    /// merge_insert.when_matched_update_all(None)
-    ///     .when_not_matched_insert_all();
+    /// merge_insert
+    ///     .when_matched_update_all(None)
+    ///     .when_not_matched_insert_all();
    /// merge_insert.execute(Box::new(new_data)).await.unwrap();
    /// # });
    /// ```
@@ -266,7 +323,9 @@ pub trait Table: std::fmt::Display + Send + Sync {
    /// # use futures::TryStreamExt;
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
-    /// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
+    /// let stream = tbl
+    ///     .query()
+    ///     .nearest_to(&[1.0, 2.0, 3.0])
    ///     .refine_factor(5)
    ///     .nprobes(10)
    ///     .execute_stream()
@@ -299,11 +358,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
    /// # use futures::TryStreamExt;
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
-    /// let stream = tbl
-    ///     .query()
-    ///     .execute_stream()
-    ///     .await
-    ///     .unwrap();
+    /// let stream = tbl.query().execute_stream().await.unwrap();
    /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
    /// # });
    /// ```
@@ -351,7 +406,7 @@ impl NativeTable {
    /// * A [NativeTable] object.
    pub async fn open(uri: &str) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
-        Self::open_with_params(uri, &name, None, ReadParams::default()).await
+        Self::open_with_params(uri, &name, None, None).await
    }
 
    /// Opens an existing Table
@@ -369,8 +424,9 @@
        uri: &str,
        name: &str,
        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
-        params: ReadParams,
+        params: Option<ReadParams>,
    ) -> Result<Self> {
+        let params = params.unwrap_or_default();
        // patch the params if we have a write store wrapper
        let params = match write_store_wrapper.clone() {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
@@ -403,7 +459,6 @@ impl NativeTable {
    }
 
    /// Checkout a specific version of this [NativeTable]
-    ///
    pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
        Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
@@ -489,13 +544,14 @@
        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
        params: Option<WriteParams>,
    ) -> Result<Self> {
+        let params = params.unwrap_or_default();
        // patch the params if we have a write store wrapper
        let params = match write_store_wrapper.clone() {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        };
 
-        let dataset = Dataset::write(batches, uri, params)
+        let dataset = Dataset::write(batches, uri, Some(params))
            .await
            .map_err(|e| match e {
                lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -513,6 +569,17 @@
        })
    }
 
+    pub async fn create_empty(
+        uri: &str,
+        name: &str,
+        schema: SchemaRef,
+        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
+        params: Option<WriteParams>,
+    ) -> Result<Self> {
+        let batches = RecordBatchIterator::new(vec![], schema);
+        Self::create(uri, name, batches, write_store_wrapper, params).await
+    }
+
    /// Version of this Table
    pub fn version(&self) -> u64 {
        self.dataset.lock().expect("lock poison").version().version
@@ -740,20 +807,26 @@ impl Table for NativeTable {
    async fn add(
        &self,
        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
+        params: AddDataOptions,
    ) -> Result<()> {
-        let params = Some(params.unwrap_or(WriteParams {
-            mode: WriteMode::Append,
-            ..WriteParams::default()
-        }));
+        let lance_params = params
+            .write_options
+            .lance_write_params
+            .unwrap_or(WriteParams {
+                mode: match params.mode {
+                    AddDataMode::Append => WriteMode::Append,
+                    AddDataMode::Overwrite => WriteMode::Overwrite,
+                },
+                ..Default::default()
+            });
 
        // patch the params if we have a write store wrapper
-        let params = match self.store_wrapper.clone() {
-            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
-            None => params,
+        let lance_params = match self.store_wrapper.clone() {
+            Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
+            None => lance_params,
        };
 
-        self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
+        self.reset_dataset(Dataset::write(batches, &self.uri, Some(lance_params)).await?);
        Ok(())
    }
 
@@ -881,25 +954,6 @@ mod tests {
        assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
    }
 
-    #[tokio::test]
-    async fn test_create_already_exists() {
-        let tmp_dir = tempdir().unwrap();
-        let uri = tmp_dir.path().to_str().unwrap();
-
-        let batches = make_test_batches();
-        let _ = batches.schema().clone();
-        NativeTable::create(uri, "test", batches, None, None)
-            .await
-            .unwrap();
-
-        let batches = make_test_batches();
-        let result = NativeTable::create(uri, "test", batches, None, None).await;
-        assert!(matches!(
-            result.unwrap_err(),
-            Error::TableAlreadyExists { .. }
-        ));
-    }
-
    #[tokio::test]
    async fn test_count_rows() {
        let tmp_dir = tempdir().unwrap();
@@ -940,7 +994,10 @@ mod tests {
            schema.clone(),
        );
 
-        table.add(Box::new(new_batches), None).await.unwrap();
+        table
+            .add(Box::new(new_batches), AddDataOptions::default())
+            .await
+            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 20);
        assert_eq!(table.name, "test");
    }
@@ -1003,23 +1060,47 @@ mod tests {
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);
 
-        let new_batches = RecordBatchIterator::new(
-            vec![RecordBatch::try_new(
-                schema.clone(),
-                vec![Arc::new(Int32Array::from_iter_values(100..110))],
-            )
-            .unwrap()]
-            .into_iter()
-            .map(Ok),
+        let batches = vec![RecordBatch::try_new(
            schema.clone(),
-        );
+            vec![Arc::new(Int32Array::from_iter_values(100..110))],
+        )
+        .unwrap()]
+        .into_iter()
+        .map(Ok);
+
+        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
+
+        // Can overwrite using AddDataOptions::mode
+        table
+            .add(
+                Box::new(new_batches),
+                AddDataOptions {
+                    mode: AddDataMode::Overwrite,
+                    ..Default::default()
+                },
+            )
+            .await
+            .unwrap();
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);
+        assert_eq!(table.name, "test");
+
+        // Can overwrite using underlying WriteParams (which
+        // take precedence over AddDataOptions::mode)
        let param: WriteParams = WriteParams {
            mode: WriteMode::Overwrite,
            ..Default::default()
        };
-        table.add(Box::new(new_batches), Some(param)).await.unwrap();
+        let opts = AddDataOptions {
+            write_options: WriteOptions {
+                lance_write_params: Some(param),
+            },
+            mode: AddDataMode::Append,
+        };
+
+        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
+        table.add(Box::new(new_batches), opts).await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);
        assert_eq!(table.name, "test");
    }
@@ -1329,7 +1410,7 @@ mod tests {
            ..Default::default()
        };
        assert!(!wrapper.called());
-        let _ = NativeTable::open_with_params(uri, "test", None, param)
+        let _ = NativeTable::open_with_params(uri, "test", None, Some(param))
            .await
            .unwrap();
        assert!(wrapper.called());
diff --git a/rust/vectordb/src/utils.rs b/rust/vectordb/src/utils.rs
index a3a03382..ec508de3 100644
--- a/rust/vectordb/src/utils.rs
+++ b/rust/vectordb/src/utils.rs
@@ -32,20 +32,17 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
    }
 }
 
 pub trait PatchWriteParam {
-    fn patch_with_store_wrapper(
-        self,
-        wrapper: Arc<dyn WrappingObjectStore>,
-    ) -> Result<Option<WriteParams>>;
+    fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
+        -> Result<WriteParams>;
 }
 
-impl PatchWriteParam for Option<WriteParams> {
+impl PatchWriteParam for WriteParams {
    fn patch_with_store_wrapper(
-        self,
+        mut self,
        wrapper: Arc<dyn WrappingObjectStore>,
-    ) -> Result<Option<WriteParams>> {
-        let mut params = self.unwrap_or_default();
-        params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
-        Ok(Some(params))
+    ) -> Result<WriteParams> {
+        self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
+        Ok(self)
    }
 }
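+
+// Example (sketch): `PatchWriteParam` is now implemented on `WriteParams`
+// itself, so callers patch concrete params; `wrapper` stands in for an
+// `Arc<dyn WrappingObjectStore>` supplied by the caller:
+//
+//     let params = WriteParams::default().patch_with_store_wrapper(wrapper)?;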