feat: extra headers parameter in client options (#2091)

Closes #1106

Unfortunately, these headers need to be set at the connection level. I
investigated whether, if we let users provide a callback, they could use
`AsyncLocalStorage` to access their context. However, it doesn't seem
like NAPI supports this right now, so I filed an issue:
https://github.com/napi-rs/napi-rs/issues/2456
This commit is contained in:
Will Jones
2025-02-04 17:26:45 -08:00
committed by GitHub
parent c269524b2f
commit 16851389ea
7 changed files with 73 additions and 2 deletions

View File

@@ -8,6 +8,14 @@
## Properties
### extraHeaders?
```ts
optional extraHeaders: Record<string, string>;
```
***
### retryConfig?
```ts

View File

@@ -104,4 +104,26 @@ describe("remote connection", () => {
},
);
});
// Verifies that headers configured through `clientConfig.extraHeaders` are
// present on the requests the connection sends to the server.
it("should pass on requested extra headers", async () => {
  const extraHeaders = { "x-my-header": "my-value" };

  await withMockDatabase(
    // Mock server side: check the custom header arrived, then answer with an
    // empty table listing.
    (req, res) => {
      expect(req.headers["x-my-header"]).toEqual("my-value");
      const payload = JSON.stringify({ tables: [] });
      res.writeHead(200, { "Content-Type": "application/json" });
      res.end(payload);
    },
    // Client side: any request will do; listing tables triggers one.
    async (db) => {
      expect(await db.tableNames()).toEqual([]);
    },
    { clientConfig: { extraHeaders } },
  );
});
});

View File

@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use napi_derive::*;
/// Timeout configuration for remote HTTP client.
@@ -67,6 +69,7 @@ pub struct ClientConfig {
pub user_agent: Option<String>,
pub retry_config: Option<RetryConfig>,
pub timeout_config: Option<TimeoutConfig>,
pub extra_headers: Option<HashMap<String, String>>,
}
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
@@ -104,6 +107,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
.unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()),
retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
extra_headers: config.extra_headers.unwrap_or_default(),
}
}
}

View File

@@ -109,6 +109,7 @@ class ClientConfig:
user_agent: str = f"LanceDB-Python-Client/{__version__}"
retry_config: RetryConfig = field(default_factory=RetryConfig)
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
extra_headers: Optional[dict] = None
def __post_init__(self):
if isinstance(self.retry_config, dict):

View File

@@ -57,7 +57,7 @@ def mock_lancedb_connection(handler):
@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler):
async def mock_lancedb_connection_async(handler, **client_config):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
@@ -73,6 +73,7 @@ async def mock_lancedb_connection_async(handler):
"timeout_config": {
"connect_timeout": 1,
},
**client_config,
},
)
@@ -522,3 +523,19 @@ def test_create_client():
with pytest.warns(DeprecationWarning):
lancedb.connect(**mandatory_args, request_thread_pool=10)
@pytest.mark.asyncio
async def test_pass_through_headers():
    """Check that ``extra_headers`` passed to the connection reach the server."""
    response_body = b'{"tables": []}'

    def handler(request):
        # The mock server must see the header configured on the connection below.
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(response_body)

    connection = mock_lancedb_connection_async(handler, extra_headers={"foo": "bar"})
    async with connection as db:
        # Listing tables issues a request, which exercises the handler above.
        assert await db.table_names() == []

View File

@@ -223,6 +223,7 @@ pub struct PyClientConfig {
user_agent: String,
retry_config: Option<PyClientRetryConfig>,
timeout_config: Option<PyClientTimeoutConfig>,
extra_headers: Option<HashMap<String, String>>,
}
#[derive(FromPyObject)]
@@ -274,6 +275,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
user_agent: value.user_agent,
retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
extra_headers: value.extra_headers.unwrap_or_default(),
}
}
}

View File

@@ -1,8 +1,9 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{future::Future, time::Duration};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use http::HeaderName;
use log::debug;
use reqwest::{
header::{HeaderMap, HeaderValue},
@@ -23,6 +24,7 @@ pub struct ClientConfig {
/// name and version.
pub user_agent: String,
// TODO: how to configure request ids?
pub extra_headers: HashMap<String, String>,
}
impl Default for ClientConfig {
@@ -31,6 +33,7 @@ impl Default for ClientConfig {
timeout_config: TimeoutConfig::default(),
retry_config: RetryConfig::default(),
user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
extra_headers: HashMap::new(),
}
}
}
@@ -256,6 +259,7 @@ impl RestfulLanceDbClient<Sender> {
host_override.is_some(),
options,
db_prefix,
&client_config,
)?)
.user_agent(client_config.user_agent)
.build()
@@ -291,6 +295,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
has_host_override: bool,
options: &RemoteOptions,
db_prefix: Option<&str>,
config: &ClientConfig,
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
@@ -345,6 +350,18 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
);
}
// Forward user-supplied extra headers. Each name and value is validated
// separately so the error identifies which part was rejected: a header name
// can be invalid while being pure ASCII (e.g. it contains a space), so it
// must not be reported as a non-ascii value.
for (key, value) in &config.extra_headers {
    let header_name = HeaderName::from_str(key).map_err(|_| Error::InvalidInput {
        message: format!("invalid header name '{}' provided", key),
    })?;
    let header_value = HeaderValue::from_str(value).map_err(|_| Error::InvalidInput {
        message: format!("non-ascii value for header '{}' provided", key),
    })?;
    // `insert` replaces any value already stored under the same name.
    headers.insert(header_name, header_value);
}
Ok(headers)
}