From 16851389ea4e52fa48ef8198b3c76295c7b06d4c Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 4 Feb 2025 17:26:45 -0800 Subject: [PATCH] feat: extra headers parameter in client options (#2091) Closes #1106 Unfortunately, the extra headers need to be set at the connection level. I investigated whether, if we let users provide a callback, they could use `AsyncLocalStorage` to access their context. However, it doesn't seem like NAPI supports this right now. I filed an issue: https://github.com/napi-rs/napi-rs/issues/2456 --- docs/src/js/interfaces/ClientConfig.md | 8 ++++++++ nodejs/__test__/remote.test.ts | 22 ++++++++++++++++++++++ nodejs/src/remote.rs | 4 ++++ python/python/lancedb/remote/__init__.py | 1 + python/python/tests/test_remote_db.py | 19 ++++++++++++++++++- python/src/connection.rs | 2 ++ rust/lancedb/src/remote/client.rs | 19 ++++++++++++++++++- 7 files changed, 73 insertions(+), 2 deletions(-) diff --git a/docs/src/js/interfaces/ClientConfig.md b/docs/src/js/interfaces/ClientConfig.md index 23f84350..b3f2c0a6 100644 --- a/docs/src/js/interfaces/ClientConfig.md +++ b/docs/src/js/interfaces/ClientConfig.md @@ -8,6 +8,14 @@ ## Properties +### extraHeaders? + +```ts +optional extraHeaders: Record<string, string>; +``` + +*** + ### retryConfig? 
```ts diff --git a/nodejs/__test__/remote.test.ts b/nodejs/__test__/remote.test.ts index 2271bd51..eb62d9f5 100644 --- a/nodejs/__test__/remote.test.ts +++ b/nodejs/__test__/remote.test.ts @@ -104,4 +104,26 @@ describe("remote connection", () => { }, ); }); + + it("should pass on requested extra headers", async () => { + await withMockDatabase( + (req, res) => { + expect(req.headers["x-my-header"]).toEqual("my-value"); + + const body = JSON.stringify({ tables: [] }); + res.writeHead(200, { "Content-Type": "application/json" }).end(body); + }, + async (db) => { + const tableNames = await db.tableNames(); + expect(tableNames).toEqual([]); + }, + { + clientConfig: { + extraHeaders: { + "x-my-header": "my-value", + }, + }, + }, + ); + }); }); diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 38b4f43e..2ec29897 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +use std::collections::HashMap; + use napi_derive::*; /// Timeout configuration for remote HTTP client. 
@@ -67,6 +69,7 @@ pub struct ClientConfig { pub user_agent: Option, pub retry_config: Option, pub timeout_config: Option, + pub extra_headers: Option>, } impl From for lancedb::remote::TimeoutConfig { @@ -104,6 +107,7 @@ impl From for lancedb::remote::ClientConfig { .unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()), retry_config: config.retry_config.map(Into::into).unwrap_or_default(), timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(), + extra_headers: config.extra_headers.unwrap_or_default(), } } } diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 8b34aba8..40502d64 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -109,6 +109,7 @@ class ClientConfig: user_agent: str = f"LanceDB-Python-Client/{__version__}" retry_config: RetryConfig = field(default_factory=RetryConfig) timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig) + extra_headers: Optional[dict] = None def __post_init__(self): if isinstance(self.retry_config, dict): diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 64252ed3..874787da 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -57,7 +57,7 @@ def mock_lancedb_connection(handler): @contextlib.asynccontextmanager -async def mock_lancedb_connection_async(handler): +async def mock_lancedb_connection_async(handler, **client_config): with http.server.HTTPServer( ("localhost", 8080), make_mock_http_handler(handler) ) as server: @@ -73,6 +73,7 @@ async def mock_lancedb_connection_async(handler): "timeout_config": { "connect_timeout": 1, }, + **client_config, }, ) @@ -522,3 +523,19 @@ def test_create_client(): with pytest.warns(DeprecationWarning): lancedb.connect(**mandatory_args, request_thread_pool=10) + + +@pytest.mark.asyncio +async def test_pass_through_headers(): + def 
handler(request): + assert request.headers["foo"] == "bar" + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"tables": []}') + + async with mock_lancedb_connection_async( + handler, extra_headers={"foo": "bar"} + ) as db: + table_names = await db.table_names() + assert table_names == [] diff --git a/python/src/connection.rs b/python/src/connection.rs index d71da5e1..0edeebc6 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -223,6 +223,7 @@ pub struct PyClientConfig { user_agent: String, retry_config: Option, timeout_config: Option, + extra_headers: Option>, } #[derive(FromPyObject)] @@ -274,6 +275,7 @@ impl From for lancedb::remote::ClientConfig { user_agent: value.user_agent, retry_config: value.retry_config.map(Into::into).unwrap_or_default(), timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(), + extra_headers: value.extra_headers.unwrap_or_default(), } } } diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index ed3a6f27..89e06f7a 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -1,8 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use std::{future::Future, time::Duration}; +use std::{collections::HashMap, future::Future, str::FromStr, time::Duration}; +use http::HeaderName; use log::debug; use reqwest::{ header::{HeaderMap, HeaderValue}, @@ -23,6 +24,7 @@ pub struct ClientConfig { /// name and version. pub user_agent: String, // TODO: how to configure request ids? 
+ pub extra_headers: HashMap<String, String>, } impl Default for ClientConfig { @@ -31,6 +33,7 @@ impl Default for ClientConfig { timeout_config: TimeoutConfig::default(), retry_config: RetryConfig::default(), user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(), + extra_headers: HashMap::new(), } } } @@ -256,6 +259,7 @@ impl RestfulLanceDbClient { host_override.is_some(), options, db_prefix, + &client_config, )?) .user_agent(client_config.user_agent) .build() @@ -291,6 +295,7 @@ impl RestfulLanceDbClient { has_host_override: bool, options: &RemoteOptions, db_prefix: Option<&str>, + config: &ClientConfig, ) -> Result<HeaderMap> { let mut headers = HeaderMap::new(); headers.insert( @@ -345,6 +350,18 @@ impl RestfulLanceDbClient { ); } + for (key, value) in &config.extra_headers { + let key_parsed = HeaderName::from_str(key).map_err(|_| Error::InvalidInput { + message: format!("invalid header name '{}' provided", key), + })?; + headers.insert( + key_parsed, + HeaderValue::from_str(value).map_err(|_| Error::InvalidInput { + message: format!("non-ascii value for header '{}' provided", key), + })?, + ); + } + Ok(headers) }