diff --git a/Cargo.lock b/Cargo.lock index 9b994ede..c058b936 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,61 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backoff" version = "0.4.0" @@ -3930,6 +3985,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -4826,6 +4882,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "async-trait", + "axum", "bytes", "futures", "lance", @@ -4837,9 +4894,12 @@ dependencies = [ "object_store", "rand 0.9.2", "reqwest", + "serde", "serde_json", "snafu", "tokio", + "tower", + "tower-http 0.5.2", "url", ] @@ -5277,6 +5337,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -7265,7 +7331,7 @@ dependencies = [ "tokio-rustls 0.26.4", "tokio-util", "tower", - "tower-http", + "tower-http 0.6.6", "tower-service", "url", "wasm-bindgen", @@ -7784,6 +7850,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -8819,6 +8896,24 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.9.4", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", ] [[package]] @@ -8857,6 +8952,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 01bbac1e..637e5c6b 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -110,7 +110,7 @@ oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"] gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"] azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"] dynamodb = ["lance/dynamodb", "aws"] -remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest"] +remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] bedrock = ["dep:aws-sdk-bedrockruntime"] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 963bce6d..b086b76f 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -804,6 +804,14 @@ impl Connection { self.internal.describe_namespace(request).await } + /// Get the equivalent namespace client in the database of this connection. + /// For LanceNamespaceDatabase, it is the underlying LanceNamespace. + /// For ListingDatabase, it is the equivalent DirectoryNamespace. + /// For RemoteDatabase, it is the equivalent RestNamespace. + pub async fn namespace_client(&self) -> Result> { + self.internal.namespace_client().await + } + /// List tables with pagination support pub async fn list_tables(&self, request: ListTablesRequest) -> Result { self.internal.list_tables(request).await diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs index 5e594aaf..36947727 100644 --- a/rust/lancedb/src/database.rs +++ b/rust/lancedb/src/database.rs @@ -296,4 +296,10 @@ pub trait Database: /// Drop all tables in the database async fn drop_all_tables(&self, namespace: &[String]) -> Result<()>; fn as_any(&self) -> &dyn std::any::Any; + + /// Get the equivalent namespace client of this database + /// For LanceNamespaceDatabase, it is the underlying LanceNamespace. + /// For ListingDatabase, it is the equivalent DirectoryNamespace. + /// For RemoteDatabase, it is the equivalent RestNamespace. + async fn namespace_client(&self) -> Result>; } diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs index 05b11e2f..626baa46 100644 --- a/rust/lancedb/src/database/listing.rs +++ b/rust/lancedb/src/database/listing.rs @@ -1043,6 +1043,24 @@ impl Database for ListingDatabase { fn as_any(&self) -> &dyn std::any::Any { self } + + async fn namespace_client(&self) -> Result> { + // Create a DirectoryNamespace pointing to the same root with the same storage options + let mut builder = lance_namespace_impls::DirectoryNamespaceBuilder::new(&self.uri); + + // Add storage options + if !self.storage_options.is_empty() { + builder = builder.storage_options(self.storage_options.clone()); + } + + // Use the same session + builder = builder.session(self.session.clone()); + + let namespace = builder.build().await.map_err(|e| Error::Runtime { + message: format!("Failed to create namespace client: {}", e), + })?; + Ok(Arc::new(namespace) as Arc) + } } #[cfg(test)] @@ -2027,4 +2045,63 @@ mod tests { let db_options = ListingDatabaseOptions::parse_from_map(&options).unwrap(); assert_eq!(db_options.new_table_config.enable_stable_row_ids, None); } + + #[tokio::test] + async fn test_namespace_client() { + let (_tempdir, db) = setup_database().await; + + // Create some tables first + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + db.create_table(CreateTableRequest { + name: "table1".to_string(), + namespace: vec![], + data: CreateTableData::Empty(TableDefinition::new_from_schema(schema.clone())), + mode: CreateTableMode::Create, + write_options: Default::default(), + location: None, + namespace_client: None, + }) + .await + .unwrap(); + + db.create_table(CreateTableRequest { + name: "table2".to_string(), + namespace: vec![], + data: CreateTableData::Empty(TableDefinition::new_from_schema(schema)), + mode: CreateTableMode::Create, + write_options: Default::default(), + location: None, + namespace_client: None, + }) + .await + .unwrap(); + + // Get the namespace client + let namespace_client = db.namespace_client().await; + assert!(namespace_client.is_ok()); + let namespace_client = namespace_client.unwrap(); + + // Verify the namespace client can list the tables we created + // Use empty vec for root namespace + let list_result = namespace_client + .list_tables(lance_namespace::models::ListTablesRequest { + id: Some(vec![]), + ..Default::default() + }) + .await; + assert!( + list_result.is_ok(), + "list_tables failed: {:?}", + list_result.err() + ); + + let tables = list_result.unwrap().tables; + assert_eq!(tables.len(), 2); + assert!(tables.contains(&"table1".to_string())); + assert!(tables.contains(&"table2".to_string())); + } } diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 176dc4bb..6f34529e 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -425,6 +425,10 @@ impl Database for LanceNamespaceDatabase { fn as_any(&self) -> &dyn std::any::Any { self } + + async fn namespace_client(&self) -> Result> { + Ok(self.namespace.clone()) + } } #[cfg(test)] diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 8dd941b4..46e51569 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -232,6 +232,38 @@ impl HttpSend for Sender { } } +/// Parsed components from a database URL (db://...) +pub struct ParsedDbUrl { + pub db_name: String, + pub db_prefix: Option, +} + +/// Parse a database URL and extract the database name and optional prefix. +/// +/// Expected format: `db://db_name` or `db://db_name/prefix` +pub fn parse_db_url(db_url: &str) -> Result { + let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput { + message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"), + })?; + debug_assert_eq!(parsed_url.scheme(), "db"); + if !parsed_url.has_host() { + return Err(Error::InvalidInput { + message: format!("Invalid database URL (missing host) '{}'", db_url), + }); + } + let db_name = parsed_url.host_str().unwrap().to_string(); + let db_prefix = { + let prefix = parsed_url.path().trim_start_matches('/'); + if prefix.is_empty() { + None + } else { + Some(prefix.to_string()) + } + }; + + Ok(ParsedDbUrl { db_name, db_prefix }) +} + impl RestfulLanceDbClient { fn get_timeout(passed: Option, env_var: &str) -> Result> { if let Some(passed) = passed { @@ -250,32 +282,12 @@ impl RestfulLanceDbClient { } pub fn try_new( - db_url: &str, - api_key: &str, + parsed_url: &ParsedDbUrl, region: &str, host_override: Option, + default_headers: HeaderMap, client_config: ClientConfig, - options: &RemoteOptions, ) -> Result { - let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput { - message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"), - })?; - debug_assert_eq!(parsed_url.scheme(), "db"); - if !parsed_url.has_host() { - return Err(Error::InvalidInput { - message: format!("Invalid database URL (missing host) '{}'", db_url), - }); - } - let db_name = parsed_url.host_str().unwrap(); - let db_prefix = { - let prefix = parsed_url.path().trim_start_matches('/'); - if prefix.is_empty() { - None - } else { - Some(prefix) - } - }; - // Get the timeouts let timeout = Self::get_timeout(client_config.timeout_config.timeout, "LANCE_CLIENT_TIMEOUT")?; @@ -348,15 +360,7 @@ impl RestfulLanceDbClient { } let client = client_builder - .default_headers(Self::default_headers( - api_key, - region, - db_name, - host_override.is_some(), - options, - db_prefix, - &client_config, - )?) + .default_headers(default_headers) .user_agent(client_config.user_agent) .build() .map_err(|err| Error::Other { @@ -366,7 +370,7 @@ impl RestfulLanceDbClient { let host = match host_override { Some(host_override) => host_override, - None => format!("https://{}.{}.api.lancedb.com", db_name, region), + None => format!("https://{}.{}.api.lancedb.com", parsed_url.db_name, region), }; debug!("Created client for host: {}", host); let retry_config = client_config.retry_config.clone().try_into()?; @@ -389,7 +393,7 @@ impl RestfulLanceDbClient { &self.host } - fn default_headers( + pub fn default_headers( api_key: &str, region: &str, db_name: &str, diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 1cdfe1d7..895c18a0 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -189,6 +189,10 @@ pub struct RemoteDatabase { client: RestfulLanceDbClient, table_cache: Cache>>, uri: String, + /// Headers to pass to the namespace client for authentication + namespace_headers: HashMap, + /// TLS configuration for mTLS support + tls_config: Option, } impl RemoteDatabase { @@ -200,13 +204,32 @@ impl RemoteDatabase { client_config: ClientConfig, options: RemoteOptions, ) -> Result { - let client = RestfulLanceDbClient::try_new( - uri, + let parsed = super::client::parse_db_url(uri)?; + let header_map = RestfulLanceDbClient::::default_headers( api_key, region, - host_override, - client_config, + &parsed.db_name, + host_override.is_some(), &options, + parsed.db_prefix.as_deref(), + &client_config, + )?; + + let namespace_headers: HashMap = header_map + .iter() + .filter_map(|(k, v)| { + v.to_str() + .ok() + .map(|val| (k.as_str().to_string(), val.to_string())) + }) + .collect(); + + let client = RestfulLanceDbClient::try_new( + &parsed, + region, + host_override, + header_map, + client_config.clone(), )?; let table_cache = Cache::builder() @@ -218,6 +241,8 @@ impl RemoteDatabase { client, table_cache, uri: uri.to_owned(), + namespace_headers, + tls_config: client_config.tls_config, }) } } @@ -240,6 +265,8 @@ mod test_utils { client, table_cache: Cache::new(0), uri: "http://localhost".to_string(), + namespace_headers: HashMap::new(), + tls_config: None, } } @@ -248,11 +275,13 @@ mod test_utils { F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, T: Into, { - let client = client_with_handler_and_config(handler, config); + let client = client_with_handler_and_config(handler, config.clone()); Self { client, table_cache: Cache::new(0), uri: "http://localhost".to_string(), + namespace_headers: config.extra_headers.clone(), + tls_config: config.tls_config.clone(), } } } @@ -716,7 +745,8 @@ impl Database for RemoteDatabase { let namespace_id = build_namespace_identifier(namespace_parts, &self.client.id_delimiter); let req = self .client - .get(&format!("/v1/namespace/{}/describe", namespace_id)); + .post(&format!("/v1/namespace/{}/describe", namespace_id)) + .json(&DescribeNamespaceRequest::default()); let (request_id, resp) = self.client.send(req).await?; let resp = self.client.check_response(&request_id, resp).await?; @@ -727,6 +757,31 @@ impl Database for RemoteDatabase { fn as_any(&self) -> &dyn std::any::Any { self } + + async fn namespace_client(&self) -> Result> { + // Create a RestNamespace pointing to the same remote host with the same authentication headers + let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host()) + .delimiter(&self.client.id_delimiter) + // TODO: support header provider + .headers(self.namespace_headers.clone()); + + // Apply mTLS configuration if present + if let Some(tls_config) = &self.tls_config { + if let Some(cert_file) = &tls_config.cert_file { + builder = builder.cert_file(cert_file); + } + if let Some(key_file) = &tls_config.key_file { + builder = builder.key_file(key_file); + } + if let Some(ssl_ca_cert) = &tls_config.ssl_ca_cert { + builder = builder.ssl_ca_cert(ssl_ca_cert); + } + builder = builder.assert_hostname(tls_config.assert_hostname); + } + + let namespace = builder.build(); + Ok(Arc::new(namespace) as Arc) + } } /// RemoteOptions contains a subset of StorageOptions that are compatible with Remote LanceDB connections @@ -1518,4 +1573,265 @@ mod tests { panic!("Expected HTTP error"); } } + + #[tokio::test] + async fn test_namespace_client() { + let conn = Connection::new_with_handler(|_| { + http::Response::builder() + .status(200) + .body(r#"{"tables": []}"#) + .unwrap() + }); + + // Get the namespace client from the connection's internal database + let namespace_client = conn.namespace_client().await; + assert!(namespace_client.is_ok()); + } + + #[tokio::test] + async fn test_namespace_client_with_tls_config() { + use crate::remote::client::TlsConfig; + + let tls_config = TlsConfig { + cert_file: Some("/path/to/cert.pem".to_string()), + key_file: Some("/path/to/key.pem".to_string()), + ssl_ca_cert: Some("/path/to/ca.pem".to_string()), + assert_hostname: true, + }; + + let client_config = ClientConfig { + tls_config: Some(tls_config), + ..Default::default() + }; + + let conn = Connection::new_with_handler_and_config( + |_| { + http::Response::builder() + .status(200) + .body(r#"{"tables": []}"#) + .unwrap() + }, + client_config, + ); + + // Get the namespace client - it should be created with the TLS config + let namespace_client = conn.namespace_client().await; + assert!(namespace_client.is_ok()); + } + + #[tokio::test] + async fn test_namespace_client_with_headers() { + let mut extra_headers = HashMap::new(); + extra_headers.insert("X-Custom-Header".to_string(), "custom-value".to_string()); + + let client_config = ClientConfig { + extra_headers, + ..Default::default() + }; + + let conn = Connection::new_with_handler_and_config( + |_| { + http::Response::builder() + .status(200) + .body(r#"{"tables": []}"#) + .unwrap() + }, + client_config, + ); + + // Get the namespace client - it should be created with the extra headers + let namespace_client = conn.namespace_client().await; + assert!(namespace_client.is_ok()); + } + + /// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server + mod rest_adapter_integration { + use super::*; + use lance_namespace::models::ListTablesRequest; + use lance_namespace_impls::{DirectoryNamespaceBuilder, RestAdapter, RestAdapterConfig}; + use std::sync::Arc; + use tempfile::TempDir; + + /// Test fixture that manages a REST server backed by DirectoryNamespace + struct RestServerFixture { + _temp_dir: TempDir, + server_handle: lance_namespace_impls::RestAdapterHandle, + server_url: String, + } + + impl RestServerFixture { + async fn new() -> Self { + let temp_dir = TempDir::new().unwrap(); + let temp_path = temp_dir.path().to_str().unwrap().to_string(); + + // Create DirectoryNamespace backend + let backend = DirectoryNamespaceBuilder::new(&temp_path) + .build() + .await + .unwrap(); + let backend = Arc::new(backend); + + // Start REST server with port 0 (OS assigns available port) + let config = RestAdapterConfig { + port: 0, + ..Default::default() + }; + + let server = RestAdapter::new(backend, config); + let server_handle = server.start().await.unwrap(); + + // Get the actual port assigned by OS + let actual_port = server_handle.port(); + let server_url = format!("http://127.0.0.1:{}", actual_port); + + Self { + _temp_dir: temp_dir, + server_handle, + server_url, + } + } + } + + impl Drop for RestServerFixture { + fn drop(&mut self) { + self.server_handle.shutdown(); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_remote_database_with_rest_adapter() { + use lance_namespace::models::CreateNamespaceRequest; + + let fixture = RestServerFixture::new().await; + + // Connect to the REST server using lancedb Connection + // Use db://dummy as URI and set actual server URL via host_override + let conn = ConnectBuilder::new("db://dummy") + .api_key("test-api-key") + .region("us-east-1") + .host_override(&fixture.server_url) + .execute() + .await + .unwrap(); + + // Create a child namespace first + let namespace = vec!["test_ns".to_string()]; + conn.create_namespace(CreateNamespaceRequest { + id: Some(namespace.clone()), + mode: None, + properties: None, + }) + .await + .expect("Failed to create namespace"); + + // Create a table in the child namespace + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); + + let table = conn + .create_table("test_table", reader) + .namespace(namespace.clone()) + .execute() + .await; + assert!(table.is_ok(), "Failed to create table: {:?}", table.err()); + + // List tables in the child namespace + let list_response = conn + .list_tables(ListTablesRequest { + id: Some(namespace.clone()), + page_token: None, + limit: None, + }) + .await + .expect("Failed to list tables"); + assert_eq!(list_response.tables, vec!["test_table"]); + + // Get namespace client and verify it can also list tables + let namespace_client = conn.namespace_client().await.unwrap(); + let list_response = namespace_client + .list_tables(ListTablesRequest { + id: Some(namespace.clone()), + page_token: None, + limit: None, + }) + .await + .unwrap(); + assert_eq!(list_response.tables, vec!["test_table"]); + + // Open the table from the child namespace + let opened_table = conn + .open_table("test_table") + .namespace(namespace.clone()) + .execute() + .await; + assert!( + opened_table.is_ok(), + "Failed to open table: {:?}", + opened_table.err() + ); + assert_eq!(opened_table.unwrap().name(), "test_table"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_remote_database_with_multiple_tables() { + use lance_namespace::models::CreateNamespaceRequest; + + let fixture = RestServerFixture::new().await; + + // Connect to the REST server + // Use db://dummy as URI and set actual server URL via host_override + let conn = ConnectBuilder::new("db://dummy") + .api_key("test-api-key") + .region("us-east-1") + .host_override(&fixture.server_url) + .execute() + .await + .unwrap(); + + // Create a child namespace first + let namespace = vec!["multi_table_ns".to_string()]; + conn.create_namespace(CreateNamespaceRequest { + id: Some(namespace.clone()), + mode: None, + properties: None, + }) + .await + .expect("Failed to create namespace"); + + // Create multiple tables in the child namespace + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + for i in 1..=3 { + let data = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![i]))]) + .unwrap(); + let reader = RecordBatchIterator::new([Ok(data.clone())], schema.clone()); + + conn.create_table(format!("table{}", i), reader) + .namespace(namespace.clone()) + .execute() + .await + .unwrap_or_else(|e| panic!("Failed to create table{}: {:?}", i, e)); + } + + // List tables in the child namespace + let list_response = conn + .list_tables(ListTablesRequest { + id: Some(namespace.clone()), + page_token: None, + limit: None, + }) + .await + .unwrap(); + assert_eq!(list_response.tables.len(), 3); + assert!(list_response.tables.contains(&"table1".to_string())); + assert!(list_response.tables.contains(&"table2".to_string())); + assert!(list_response.tables.contains(&"table3".to_string())); + } + } }