diff --git a/rust/lancedb/src/catalog.rs b/rust/lancedb/src/catalog.rs index a9b9f0bd..a42269cf 100644 --- a/rust/lancedb/src/catalog.rs +++ b/rust/lancedb/src/catalog.rs @@ -12,6 +12,10 @@ use crate::database::Database; use crate::error::Result; use async_trait::async_trait; +pub trait CatalogOptions { + fn serialize_into_map(&self, map: &mut HashMap); +} + /// Request parameters for listing databases #[derive(Clone, Debug, Default)] pub struct DatabaseNamesRequest { diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 7685da3d..a413ce0d 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -12,6 +12,8 @@ use lance::dataset::ReadParams; use object_store::aws::AwsCredential; use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream}; +use crate::catalog::listing::ListingCatalog; +use crate::catalog::CatalogOptions; use crate::database::listing::{ ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS, }; @@ -830,6 +832,52 @@ pub fn connect(uri: &str) -> ConnectBuilder { ConnectBuilder::new(uri) } +/// A builder for configuring a connection to a LanceDB catalog +#[derive(Debug)] +pub struct CatalogConnectBuilder { + request: ConnectRequest, +} + +impl CatalogConnectBuilder { + /// Create a new [`CatalogConnectBuilder`] with the given catalog URI. + pub fn new(uri: &str) -> Self { + Self { + request: ConnectRequest { + uri: uri.to_string(), + api_key: None, + region: None, + host_override: None, + #[cfg(feature = "remote")] + client_config: Default::default(), + read_consistency_interval: None, + storage_options: HashMap::new(), + }, + } + } + + pub fn catalog_options(mut self, catalog_options: &dyn CatalogOptions) -> Self { + catalog_options.serialize_into_map(&mut self.request.storage_options); + self + } + + /// Establishes a connection to the catalog + pub async fn execute(self) -> Result> { + let catalog = ListingCatalog::connect(&self.request).await?; + Ok(Arc::new(catalog)) + } +} + +/// Connect to a LanceDB catalog. +/// +/// A catalog is a container for databases, which in turn are containers for tables. +/// +/// # Arguments +/// +/// * `uri` - URI where the catalog is located, can be a local directory or supported remote cloud storage. +pub fn connect_catalog(uri: &str) -> CatalogConnectBuilder { + CatalogConnectBuilder::new(uri) +} + #[cfg(all(test, feature = "remote"))] mod test_utils { use super::*; @@ -854,6 +902,10 @@ mod test_utils { mod tests { use std::fs::create_dir_all; + use crate::catalog::{Catalog, DatabaseNamesRequest, OpenDatabaseRequest}; + use crate::database::listing::{ListingDatabaseOptions, NewTableConfig}; + use crate::query::QueryBase; + use crate::query::{ExecutableQuery, QueryExecutionOptions}; use arrow::compute::concat_batches; use arrow_array::RecordBatchReader; use arrow_schema::{DataType, Field, Schema}; @@ -864,9 +916,6 @@ mod tests { use tempfile::tempdir; use crate::arrow::SimpleRecordBatchStream; - use crate::database::listing::{ListingDatabaseOptions, NewTableConfig}; - use crate::query::QueryBase; - use crate::query::{ExecutableQuery, QueryExecutionOptions}; use super::*; @@ -1157,4 +1206,91 @@ mod tests { .unwrap(); assert_eq!(other_schema, overwritten.schema().await.unwrap()); } + + #[tokio::test] + async fn test_connect_catalog() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let catalog = connect_catalog(uri).execute().await.unwrap(); + + // Verify that we can get the uri from the catalog + let catalog_uri = catalog.uri(); + assert_eq!(catalog_uri, uri); + + // Check that the catalog is initially empty + let dbs = catalog + .database_names(DatabaseNamesRequest::default()) + .await + .unwrap(); + assert_eq!(dbs.len(), 0); + } + + #[tokio::test] + #[cfg(not(windows))] + async fn test_catalog_create_database() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let catalog = connect_catalog(uri).execute().await.unwrap(); + + let db_name = "test_db"; + catalog + .create_database(crate::catalog::CreateDatabaseRequest { + name: db_name.to_string(), + mode: Default::default(), + options: Default::default(), + }) + .await + .unwrap(); + + let dbs = catalog + .database_names(DatabaseNamesRequest::default()) + .await + .unwrap(); + assert_eq!(dbs.len(), 1); + assert_eq!(dbs[0], db_name); + + let db = catalog + .open_database(OpenDatabaseRequest { + name: db_name.to_string(), + database_options: HashMap::new(), + }) + .await + .unwrap(); + + let tables = db.table_names(Default::default()).await.unwrap(); + assert_eq!(tables.len(), 0); + } + + #[tokio::test] + #[cfg(not(windows))] + async fn test_catalog_drop_database() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let catalog = connect_catalog(uri).execute().await.unwrap(); + + // Create and then drop a database + let db_name = "test_db_to_drop"; + catalog + .create_database(crate::catalog::CreateDatabaseRequest { + name: db_name.to_string(), + mode: Default::default(), + options: Default::default(), + }) + .await + .unwrap(); + + let dbs = catalog + .database_names(DatabaseNamesRequest::default()) + .await + .unwrap(); + assert_eq!(dbs.len(), 1); + + catalog.drop_database(db_name).await.unwrap(); + + let dbs_after = catalog + .database_names(DatabaseNamesRequest::default()) + .await + .unwrap(); + assert_eq!(dbs_after.len(), 0); + } }