diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs
index a5ac62e6..1e016a1a 100644
--- a/rust/lancedb/src/database/listing.rs
+++ b/rust/lancedb/src/database/listing.rs
@@ -8,7 +8,7 @@ use std::path::Path;
 use std::{collections::HashMap, sync::Arc};
 
 use lance::dataset::{ReadParams, WriteMode};
-use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore};
+use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
 use lance_datafusion::utils::StreamingWriteSource;
 use lance_encoding::version::LanceFileVersion;
 use lance_table::io::commit::commit_handler_from_url;
@@ -217,6 +217,9 @@ pub struct ListingDatabase {
 
     // Options for tables created by this connection
     new_table_config: NewTableConfig,
+
+    // Session for object stores and caching
+    session: Arc<lance::session::Session>,
 }
 
 impl std::fmt::Display for ListingDatabase {
@@ -313,13 +316,17 @@ impl ListingDatabase {
 
                 let plain_uri = url.to_string();
 
-                let registry = Arc::new(ObjectStoreRegistry::default());
+                let session = Arc::new(lance::session::Session::default());
                 let os_params = ObjectStoreParams {
                     storage_options: Some(options.storage_options.clone()),
                     ..Default::default()
                 };
-                let (object_store, base_path) =
-                    ObjectStore::from_uri_and_params(registry, &plain_uri, &os_params).await?;
+                let (object_store, base_path) = ObjectStore::from_uri_and_params(
+                    session.store_registry(),
+                    &plain_uri,
+                    &os_params,
+                )
+                .await?;
                 if object_store.is_local() {
                     Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
                 }
@@ -342,6 +349,7 @@ impl ListingDatabase {
                     read_consistency_interval: request.read_consistency_interval,
                     storage_options: options.storage_options,
                     new_table_config: options.new_table_config,
+                    session,
                 })
             }
             Err(_) => {
@@ -360,7 +368,13 @@ impl ListingDatabase {
         read_consistency_interval: Option<std::time::Duration>,
         new_table_config: NewTableConfig,
     ) -> Result<Self> {
-        let (object_store, base_path) = ObjectStore::from_uri(path).await?;
+        let session = Arc::new(lance::session::Session::default());
+        let (object_store, base_path) = ObjectStore::from_uri_and_params(
+            session.store_registry(),
+            path,
+            &ObjectStoreParams::default(),
+        )
+        .await?;
         if object_store.is_local() {
             Self::try_create_dir(path).context(CreateDirSnafu { path })?;
         }
@@ -374,6 +388,7 @@ impl ListingDatabase {
             read_consistency_interval,
             storage_options: HashMap::new(),
             new_table_config,
+            session,
         })
     }
 
@@ -441,6 +456,128 @@ impl ListingDatabase {
         }
         Ok(())
     }
+
+    /// Inherit storage options from the connection into the target map
+    fn inherit_storage_options(&self, target: &mut HashMap<String, String>) {
+        for (key, value) in self.storage_options.iter() {
+            if !target.contains_key(key) {
+                target.insert(key.clone(), value.clone());
+            }
+        }
+    }
+
+    /// Extract storage option overrides from the request
+    fn extract_storage_overrides(
+        &self,
+        request: &CreateTableRequest,
+    ) -> Result<(Option<LanceFileVersion>, Option<bool>)> {
+        let storage_options = request
+            .write_options
+            .lance_write_params
+            .as_ref()
+            .and_then(|p| p.store_params.as_ref())
+            .and_then(|sp| sp.storage_options.as_ref());
+
+        let storage_version_override = storage_options
+            .and_then(|opts| opts.get(OPT_NEW_TABLE_STORAGE_VERSION))
+            .map(|s| s.parse::<LanceFileVersion>())
+            .transpose()?;
+
+        let v2_manifest_override = storage_options
+            .and_then(|opts| opts.get(OPT_NEW_TABLE_V2_MANIFEST_PATHS))
+            .map(|s| s.parse::<bool>())
+            .transpose()
+            .map_err(|_| Error::InvalidInput {
+                message: "enable_v2_manifest_paths must be a boolean".to_string(),
+            })?;
+
+        Ok((storage_version_override, v2_manifest_override))
+    }
+
+    /// Prepare write parameters for table creation
+    fn prepare_write_params(
+        &self,
+        request: &CreateTableRequest,
+        storage_version_override: Option<LanceFileVersion>,
+        v2_manifest_override: Option<bool>,
+    ) -> lance::dataset::WriteParams {
+        let mut write_params = request
+            .write_options
+            .lance_write_params
+            .clone()
+            .unwrap_or_default();
+
+        // Only modify the storage options if we actually have something to
+        // inherit. There is a difference between storage_options=None and
+        // storage_options=Some({}). Using storage_options=None will cause the
+        // connection's session store registry to be used. Supplying Some({})
+        // will cause a new connection to be created, and that connection will
+        // be dropped from the cache when python GCs the table object, which
+        // confounds reuse across tables.
+        if !self.storage_options.is_empty() {
+            let storage_options = write_params
+                .store_params
+                .get_or_insert_with(Default::default)
+                .storage_options
+                .get_or_insert_with(Default::default);
+            self.inherit_storage_options(storage_options);
+        }
+
+        write_params.data_storage_version = self
+            .new_table_config
+            .data_storage_version
+            .or(storage_version_override);
+
+        if let Some(enable_v2_manifest_paths) = self
+            .new_table_config
+            .enable_v2_manifest_paths
+            .or(v2_manifest_override)
+        {
+            write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
+        }
+
+        if matches!(&request.mode, CreateTableMode::Overwrite) {
+            write_params.mode = WriteMode::Overwrite;
+        }
+
+        write_params.session = Some(self.session.clone());
+
+        write_params
+    }
+
+    /// Handle the case where table already exists based on the create mode
+    async fn handle_table_exists(
+        &self,
+        table_name: &str,
+        mode: CreateTableMode,
+        data_schema: &arrow_schema::Schema,
+    ) -> Result<Arc<dyn BaseTable>> {
+        match mode {
+            CreateTableMode::Create => Err(Error::TableAlreadyExists {
+                name: table_name.to_string(),
+            }),
+            CreateTableMode::ExistOk(callback) => {
+                let req = OpenTableRequest {
+                    name: table_name.to_string(),
+                    index_cache_size: None,
+                    lance_read_params: None,
+                };
+                let req = (callback)(req);
+                let table = self.open_table(req).await?;
+
+                let table_schema = table.schema().await?;
+
+                if table_schema.as_ref() != data_schema {
+                    return Err(Error::Schema {
+                        message: "Provided schema does not match existing table schema".to_string(),
+                    });
+                }
+
+                Ok(table)
+            }
+            CreateTableMode::Overwrite => unreachable!(),
+        }
+    }
 }
 
 #[async_trait::async_trait]
@@ -475,50 +612,14 @@ impl Database for ListingDatabase {
         Ok(f)
     }
 
-    async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
+    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;
 
-        // Inherit storage options from the connection
-        let storage_options = request
-            .write_options
-            .lance_write_params
-            .get_or_insert_with(Default::default)
-            .store_params
-            .get_or_insert_with(Default::default)
-            .storage_options
-            .get_or_insert_with(Default::default);
-        for (key, value) in self.storage_options.iter() {
-            if !storage_options.contains_key(key) {
-                storage_options.insert(key.clone(), value.clone());
-            }
-        }
-        let storage_options = storage_options.clone();
+        let (storage_version_override, v2_manifest_override) =
+            self.extract_storage_overrides(&request)?;
 
-        let mut write_params = request.write_options.lance_write_params.unwrap_or_default();
-
-        if let Some(storage_version) = &self.new_table_config.data_storage_version {
-            write_params.data_storage_version = Some(*storage_version);
-        } else {
-            // Allow the user to override the storage version via storage options (backwards compatibility)
-            if let Some(data_storage_version) = storage_options.get(OPT_NEW_TABLE_STORAGE_VERSION) {
-                write_params.data_storage_version = Some(data_storage_version.parse()?);
-            }
-        }
-        if let Some(enable_v2_manifest_paths) = self.new_table_config.enable_v2_manifest_paths {
-            write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
-        } else {
-            // Allow the user to override the storage version via storage options (backwards compatibility)
-            if let Some(enable_v2_manifest_paths) = storage_options
-                .get(OPT_NEW_TABLE_V2_MANIFEST_PATHS)
-                .map(|s| s.parse::<bool>().unwrap())
-            {
-                write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
-            }
-        }
-
-        if matches!(&request.mode, CreateTableMode::Overwrite) {
-            write_params.mode = WriteMode::Overwrite;
-        }
+        let write_params =
+            self.prepare_write_params(&request, storage_version_override, v2_manifest_override);
 
         let data_schema = request.data.arrow_schema();
 
@@ -533,30 +634,10 @@ impl Database for ListingDatabase {
             .await
         {
             Ok(table) => Ok(Arc::new(table)),
-            Err(Error::TableAlreadyExists { name }) => match request.mode {
-                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
-                CreateTableMode::ExistOk(callback) => {
-                    let req = OpenTableRequest {
-                        name: request.name.clone(),
-                        index_cache_size: None,
-                        lance_read_params: None,
-                    };
-                    let req = (callback)(req);
-                    let table = self.open_table(req).await?;
-
-                    let table_schema = table.schema().await?;
-
-                    if table_schema != data_schema {
-                        return Err(Error::Schema {
-                            message: "Provided schema does not match existing table schema"
-                                .to_string(),
-                        });
-                    }
-
-                    Ok(table)
-                }
-                CreateTableMode::Overwrite => unreachable!(),
-            },
+            Err(Error::TableAlreadyExists { .. }) => {
+                self.handle_table_exists(&request.name, request.mode, &data_schema)
+                    .await
+            }
             Err(err) => Err(err),
         }
     }
@@ -564,18 +645,22 @@ impl Database for ListingDatabase {
     async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;
 
-        // Inherit storage options from the connection
-        let storage_options = request
-            .lance_read_params
-            .get_or_insert_with(Default::default)
-            .store_options
-            .get_or_insert_with(Default::default)
-            .storage_options
-            .get_or_insert_with(Default::default);
-        for (key, value) in self.storage_options.iter() {
-            if !storage_options.contains_key(key) {
-                storage_options.insert(key.clone(), value.clone());
-            }
+        // Only modify the storage options if we actually have something to
+        // inherit. There is a difference between storage_options=None and
+        // storage_options=Some({}). Using storage_options=None will cause the
+        // connection's session store registry to be used. Supplying Some({})
+        // will cause a new connection to be created, and that connection will
+        // be dropped from the cache when python GCs the table object, which
+        // confounds reuse across tables.
+        if !self.storage_options.is_empty() {
+            let storage_options = request
+                .lance_read_params
+                .get_or_insert_with(Default::default)
+                .store_options
+                .get_or_insert_with(Default::default)
+                .storage_options
+                .get_or_insert_with(Default::default);
+            self.inherit_storage_options(storage_options);
         }
 
         // Some ReadParams are exposed in the OpenTableBuilder, but we also
@@ -584,13 +669,14 @@ impl Database for ListingDatabase {
         // If we have a user provided ReadParams use that
         // If we don't then start with the default ReadParams and customize it with
        // the options from the OpenTableBuilder
-        let read_params = request.lance_read_params.unwrap_or_else(|| {
+        let mut read_params = request.lance_read_params.unwrap_or_else(|| {
             let mut default_params = ReadParams::default();
             if let Some(index_cache_size) = request.index_cache_size {
                 default_params.index_cache_size = index_cache_size as usize;
             }
             default_params
         });
+        read_params.session(self.session.clone());
 
         let native_table = Arc::new(
             NativeTable::open_with_params(
diff --git a/rust/lancedb/tests/object_store_test.rs b/rust/lancedb/tests/object_store_test.rs
index c5beea46..b2deb4b5 100644
--- a/rust/lancedb/tests/object_store_test.rs
+++ b/rust/lancedb/tests/object_store_test.rs
@@ -281,6 +281,46 @@ async fn test_encryption() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_table_storage_options_override() -> Result<()> {
+    // Test that table-level storage options override connection-level options
+    let bucket = S3Bucket::new("test-override").await;
+    let key1 = KMSKey::new().await;
+    let key2 = KMSKey::new().await;
+
+    let uri = format!("s3://{}", bucket.0);
+
+    // Create connection with key1 encryption
+    let db = lancedb::connect(&uri)
+        .storage_options(CONFIG.iter().cloned())
+        .storage_option("aws_server_side_encryption", "aws:kms")
+        .storage_option("aws_sse_kms_key_id", &key1.0)
+        .execute()
+        .await?;
+
+    // Create table overriding with key2 encryption
+    let data = test_data();
+    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
+    let _table = db
+        .create_table("test_override", data)
+        .storage_option("aws_sse_kms_key_id", &key2.0)
+        .execute()
+        .await?;
+
+    // Verify objects are encrypted with key2, not key1
+    validate_objects_encrypted(&bucket.0, "test_override", &key2.0).await;
+
+    // Also test that a table created without override uses connection settings
+    let data = test_data();
+    let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
+    let _table2 = db.create_table("test_inherit", data).execute().await?;
+
+    // Verify this table uses key1 from connection
+    validate_objects_encrypted(&bucket.0, "test_inherit", &key1.0).await;
+
+    Ok(())
+}
+
 struct DynamoDBCommitTable(String);
 
 impl DynamoDBCommitTable {
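Below is a minimal usage sketch, not part of the patch, illustrating the behavior this change is aiming for. It assumes only lancedb's public Rust builders (connect, create_table, open_table) together with standard arrow-rs types; the bucket URI, region value, and table name are placeholders. The point follows from the comment added above: when no table-level storage options are supplied, storage_options stays None, so tables created or opened through the same connection resolve their object store through the connection's shared Session and its store registry instead of constructing a fresh object store per table.

use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};

// Sketch only: the URI, region, and table name are placeholders.
async fn reuse_connection_session() -> lancedb::Result<()> {
    // Connection-level storage options are set once and inherited by every
    // table created or opened through this connection.
    let db = lancedb::connect("s3://example-bucket/db")
        .storage_option("aws_region", "us-east-1")
        .execute()
        .await?;

    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .expect("valid batch");
    let data = RecordBatchIterator::new(vec![Ok(batch)], schema);

    // No per-table storage options are supplied, so lance_write_params /
    // lance_read_params keep storage_options = None and both calls go through
    // the connection's session store registry, sharing the cached object store.
    let _created = db.create_table("example_table", data).execute().await?;
    let _reopened = db.open_table("example_table").execute().await?;
    Ok(())
}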