diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index b1bc9ca2..3fcd5c5a 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -22,8 +22,9 @@ use object_store::CredentialProvider; use once_cell::sync::OnceCell; use tokio::runtime::Runtime; -use vectordb::database::Database; +use vectordb::connection::Database; use vectordb::table::ReadParams; +use vectordb::Connection; use crate::error::ResultExt; use crate::query::JsQuery; @@ -38,7 +39,7 @@ mod query; mod table; struct JsDatabase { - database: Arc, + database: Arc, } impl Finalize for JsDatabase {} diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index fd4b6ef3..47c49ce9 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -77,7 +77,7 @@ impl JsTable { rt.spawn(async move { let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); let table_rst = database - .create_table(&table_name, batch_reader, Some(params)) + .create_table(&table_name, Box::new(batch_reader), Some(params)) .await; deferred.settle_with(&channel, move |mut cx| { diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/connection.rs similarity index 90% rename from rust/vectordb/src/database.rs rename to rust/vectordb/src/connection.rs index 7d1d2c35..1e8308bc 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/connection.rs @@ -31,6 +31,40 @@ use crate::table::{ReadParams, Table}; pub const LANCE_FILE_EXTENSION: &str = "lance"; +/// A connection to LanceDB +#[async_trait::async_trait] +pub trait Connection: Send + Sync { + /// Get the names of all tables in the database. + async fn table_names(&self) -> Result>; + + /// Create a new table in the database. + /// + /// # Parameters + /// + /// * `name` - The name of the table. + /// * `batches` - The initial data to write to the table. + /// * `params` - Optional [`WriteParams`] to create the table. + /// + /// # Returns + /// Created [`Table`], or [`Err(Error::TableAlreadyExists)`] if the table already exists. + async fn create_table( + &self, + name: &str, + batches: Box, + params: Option, + ) -> Result; + + async fn open_table(&self, name: &str) -> Result
; + + async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result
; + + /// Drop a table in the database. + /// + /// # Arguments + /// * `name` - The name of the table. + async fn drop_table(&self, name: &str) -> Result<()>; +} + pub struct Database { object_store: ObjectStore, query_string: Option, @@ -52,7 +86,7 @@ impl Database { /// /// # Arguments /// - /// * `path` - URI where the database is located, can be a local file or a supported remote cloud storage + /// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage /// /// # Returns /// @@ -158,12 +192,30 @@ impl Database { Ok(()) } - /// Get the names of all tables in the database. - /// - /// # Returns - /// - /// * A [`Vec`] with all table names. - pub async fn table_names(&self) -> Result> { + /// Get the URI of a table in the database. + fn table_uri(&self, name: &str) -> Result { + let path = Path::new(&self.uri); + let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); + + let mut uri = table_uri + .as_path() + .to_str() + .context(InvalidTableNameSnafu { name })? + .to_string(); + + // If there are query string set on the connection, propagate to lance + if let Some(query) = self.query_string.as_ref() { + uri.push('?'); + uri.push_str(query.as_str()); + } + + Ok(uri) + } +} + +#[async_trait::async_trait] +impl Connection for Database { + async fn table_names(&self) -> Result> { let mut f = self .object_store .read_dir(self.base_path.clone()) @@ -183,16 +235,10 @@ impl Database { Ok(f) } - /// Create a new table in the database. - /// - /// # Arguments - /// * `name` - The name of the table. - /// * `batches` - The initial data to write to the table. - /// * `params` - Optional [`WriteParams`] to create the table. - pub async fn create_table( + async fn create_table( &self, name: &str, - batches: impl RecordBatchReader + Send + 'static, + batches: Box, params: Option, ) -> Result
{ let table_uri = self.table_uri(name)?; @@ -215,7 +261,7 @@ impl Database { /// # Returns /// /// * A [Table] object. - pub async fn open_table(&self, name: &str) -> Result
{ + async fn open_table(&self, name: &str) -> Result
{ self.open_table_with_params(name, ReadParams::default()) .await } @@ -229,41 +275,17 @@ impl Database { /// # Returns /// /// * A [Table] object. - pub async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result
{ + async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result
{ let table_uri = self.table_uri(name)?; Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await } - /// Drop a table in the database. - /// - /// # Arguments - /// * `name` - The name of the table. - pub async fn drop_table(&self, name: &str) -> Result<()> { + async fn drop_table(&self, name: &str) -> Result<()> { let dir_name = format!("{}.{}", name, LANCE_EXTENSION); let full_path = self.base_path.child(dir_name.clone()); self.object_store.remove_dir_all(full_path).await?; Ok(()) } - - /// Get the URI of a table in the database. - fn table_uri(&self, name: &str) -> Result { - let path = Path::new(&self.uri); - let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); - - let mut uri = table_uri - .as_path() - .to_str() - .context(InvalidTableNameSnafu { name })? - .to_string(); - - // If there are query string set on the connection, propagate to lance - if let Some(query) = self.query_string.as_ref() { - uri.push('?'); - uri.push_str(query.as_str()); - } - - Ok(uri) - } } #[cfg(test)] @@ -272,7 +294,7 @@ mod tests { use tempfile::tempdir; - use crate::database::Database; + use super::*; #[tokio::test] async fn test_connect() { diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs index e323f9b9..7e812821 100644 --- a/rust/vectordb/src/io/object_store.rs +++ b/rust/vectordb/src/io/object_store.rs @@ -335,7 +335,7 @@ impl WrappingObjectStore for MirroringObjectStoreWrapper { #[cfg(all(test, not(windows)))] mod test { use super::*; - use crate::Database; + use crate::connection::{Connection, Database}; use arrow_array::PrimitiveArray; use futures::TryStreamExt; use lance::{dataset::WriteParams, io::object_store::ObjectStoreParams}; @@ -365,7 +365,7 @@ mod test { datagen = datagen.col(Box::new(RandomVector::default().named("vector".into()))); let res = db - .create_table("test", datagen.batch(100), Some(param.clone())) + .create_table("test", Box::new(datagen.batch(100)), Some(param.clone())) .await; // leave this here for easy debugging diff --git a/rust/vectordb/src/lib.rs b/rust/vectordb/src/lib.rs index 6aa4dd91..7125db22 100644 --- a/rust/vectordb/src/lib.rs +++ b/rust/vectordb/src/lib.rs @@ -46,7 +46,7 @@ //! #### Connect to a database. //! //! ```rust -//! use vectordb::{Database, Table, WriteMode}; +//! use vectordb::{connection::{Database, Connection}, Table, WriteMode}; //! use arrow_schema::{Field, Schema}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! let db = Database::connect("data/sample-lancedb").await.unwrap(); @@ -66,7 +66,7 @@ //! use arrow_schema::{DataType, Schema, Field}; //! use arrow_array::{RecordBatch, RecordBatchIterator}; //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; -//! # use vectordb::Database; +//! # use vectordb::connection::{Database, Connection}; //! //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); @@ -86,7 +86,7 @@ //! ]).unwrap() //! ].into_iter().map(Ok), //! schema.clone()); -//! db.create_table("my_table", batches, None).await.unwrap(); +//! db.create_table("my_table", Box::new(batches), None).await.unwrap(); //! # }); //! ``` //! @@ -98,7 +98,7 @@ //! # use arrow_schema::{DataType, Schema, Field}; //! # use arrow_array::{RecordBatch, RecordBatchIterator}; //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; -//! # use vectordb::Database; +//! # use vectordb::connection::{Database, Connection}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # let tmpdir = tempfile::tempdir().unwrap(); //! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap(); @@ -116,7 +116,7 @@ //! # ]).unwrap() //! # ].into_iter().map(Ok), //! # schema.clone()); -//! # db.create_table("my_table", batches, None).await.unwrap(); +//! # db.create_table("my_table", Box::new(batches), None).await.unwrap(); //! let table = db.open_table("my_table").await.unwrap(); //! let results = table //! .search(Some(vec![1.0; 128])) @@ -131,8 +131,8 @@ //! //! ``` +pub mod connection; pub mod data; -pub mod database; pub mod error; pub mod index; pub mod io; @@ -140,7 +140,7 @@ pub mod query; pub mod table; pub mod utils; -pub use database::Database; +pub use connection::Connection; pub use table::Table; pub use lance::dataset::WriteMode;