diff --git a/Cargo.toml b/Cargo.toml index 4a4d667f..1c529852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ rust-version = "1.80.0" # TO lance = { "version" = "=0.20.0", "features" = [ "dynamodb", ], git = "https://github.com/lancedb/lance.git", tag = "v0.20.0-beta.3" } +lance-io = { version = "=0.20.0", git = "https://github.com/lancedb/lance.git", tag = "v0.20.0-beta.3" } lance-index = { version = "=0.20.0", git = "https://github.com/lancedb/lance.git", tag = "v0.20.0-beta.3" } lance-linalg = { version = "=0.20.0", git = "https://github.com/lancedb/lance.git", tag = "v0.20.0-beta.3" } lance-table = { version = "=0.20.0", git = "https://github.com/lancedb/lance.git", tag = "v0.20.0-beta.3" } diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 2ac09fea..c5308672 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -27,6 +27,7 @@ half = { workspace = true } lazy_static.workspace = true lance = { workspace = true } lance-datafusion.workspace = true +lance-io = { workspace = true } lance-index = { workspace = true } lance-table = { workspace = true } lance-linalg = { workspace = true } diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index f7e17b39..90173fd3 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -38,6 +38,8 @@ use crate::table::{NativeTable, TableDefinition, WriteOptions}; use crate::utils::validate_table_name; use crate::Table; pub use lance_encoding::version::LanceFileVersion; +#[cfg(feature = "remote")] +use lance_io::object_store::StorageOptions; use lance_table::io::commit::commit_handler_from_url; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -718,12 +720,14 @@ impl ConnectBuilder { message: "An api_key is required when connecting to LanceDb Cloud".to_string(), })?; + let storage_options = StorageOptions(self.storage_options.clone()); let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( &self.uri, &api_key, ®ion, self.host_override, self.client_config, + storage_options.into(), )?); Ok(Connection { internal, @@ -856,7 +860,7 @@ impl Database { let table_base_uri = if let Some(store) = engine { static WARN_ONCE: std::sync::Once = std::sync::Once::new(); WARN_ONCE.call_once(|| { - log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE"); + log::warn!("Specifying engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE"); }); let old_scheme = url.scheme().to_string(); let new_scheme = format!("{}+{}", old_scheme, store); diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 48c8aa1c..c560332b 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -21,6 +21,7 @@ use reqwest::{ }; use crate::error::{Error, Result}; +use crate::remote::db::RemoteOptions; const REQUEST_ID_HEADER: &str = "x-request-id"; @@ -215,6 +216,7 @@ impl RestfulLanceDbClient { region: &str, host_override: Option, client_config: ClientConfig, + options: &RemoteOptions, ) -> Result { let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput { message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"), @@ -255,6 +257,7 @@ impl RestfulLanceDbClient { region, db_name, host_override.is_some(), + options, )?) .user_agent(client_config.user_agent) .build() @@ -262,6 +265,7 @@ impl RestfulLanceDbClient { message: "Failed to build HTTP client".into(), source: Some(Box::new(err)), })?; + let host = match host_override { Some(host_override) => host_override, None => format!("https://{}.{}.api.lancedb.com", db_name, region), @@ -287,6 +291,7 @@ impl RestfulLanceDbClient { region: &str, db_name: &str, has_host_override: bool, + options: &RemoteOptions, ) -> Result { let mut headers = HeaderMap::new(); headers.insert( @@ -313,6 +318,23 @@ impl RestfulLanceDbClient { ); } + if let Some(v) = options.0.get("account_name") { + headers.insert( + "x-azure-storage-account-name", + HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { + message: format!("non-ascii storage account name '{}' provided", db_name), + })?, + ); + } + if let Some(v) = options.0.get("azure_storage_account_name") { + headers.insert( + "x-azure-storage-account-name", + HeaderValue::from_str(v).map_err(|_| Error::InvalidInput { + message: format!("non-ascii storage account name '{}' provided", db_name), + })?, + ); + } + Ok(headers) } diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index fc10fbdb..5d0ebedf 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::sync::Arc; use arrow_array::RecordBatchReader; use async_trait::async_trait; use http::StatusCode; +use lance_io::object_store::StorageOptions; use moka::future::Cache; use reqwest::header::CONTENT_TYPE; use serde::Deserialize; @@ -53,9 +55,16 @@ impl RemoteDatabase { region: &str, host_override: Option, client_config: ClientConfig, + options: RemoteOptions, ) -> Result { - let client = - RestfulLanceDbClient::try_new(uri, api_key, region, host_override, client_config)?; + let client = RestfulLanceDbClient::try_new( + uri, + api_key, + region, + host_override, + client_config, + &options, + )?; let table_cache = Cache::builder() .time_to_live(std::time::Duration::from_secs(300)) @@ -243,6 +252,29 @@ impl ConnectionInternal for RemoteDatabase { } } +/// RemoteOptions contains a subset of StorageOptions that are compatible with Remote LanceDB connections +#[derive(Clone, Debug, Default)] +pub struct RemoteOptions(pub HashMap); + +impl RemoteOptions { + pub fn new(options: HashMap) -> Self { + Self(options) + } +} + +impl From for RemoteOptions { + fn from(options: StorageOptions) -> Self { + let supported_opts = vec!["account_name", "azure_storage_account_name"]; + let mut filtered = HashMap::new(); + for opt in supported_opts { + if let Some(v) = options.0.get(opt) { + filtered.insert(opt.to_string(), v.to_string()); + } + } + RemoteOptions::new(filtered) + } +} + #[cfg(test)] mod tests { use std::sync::{Arc, OnceLock}; @@ -250,6 +282,7 @@ mod tests { use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; + use crate::connection::ConnectBuilder; use crate::{ connection::CreateTableMode, remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE}, @@ -541,4 +574,16 @@ mod tests { }); conn.rename_table("table1", "table2").await.unwrap(); } + + #[tokio::test] + async fn test_connect_remote_options() { + let db_uri = "db://my-container/my-prefix"; + let _ = ConnectBuilder::new(db_uri) + .region("us-east-1") + .api_key("my-api-key") + .storage_options(vec![("azure_storage_account_name", "my-storage-account")]) + .execute() + .await + .unwrap(); + } }