mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
feat: extra headers parameter in client options (#2091)
Closes #1106 Unfortunately, these need to be set at the connection level. I investigated whether if we let users provide a callback they could use `AsyncLocalStorage` to access their context. However, it doesn't seem like NAPI supports this right now. I filed an issue: https://github.com/napi-rs/napi-rs/issues/2456
This commit is contained in:
@@ -8,6 +8,14 @@
|
||||
|
||||
## Properties
|
||||
|
||||
### extraHeaders?
|
||||
|
||||
```ts
|
||||
optional extraHeaders: Record<string, string>;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### retryConfig?
|
||||
|
||||
```ts
|
||||
|
||||
@@ -104,4 +104,26 @@ describe("remote connection", () => {
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should pass on requested extra headers", async () => {
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["x-my-header"]).toEqual("my-value");
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async (db) => {
|
||||
const tableNames = await db.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
{
|
||||
clientConfig: {
|
||||
extraHeaders: {
|
||||
"x-my-header": "my-value",
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use napi_derive::*;
|
||||
|
||||
/// Timeout configuration for remote HTTP client.
|
||||
@@ -67,6 +69,7 @@ pub struct ClientConfig {
|
||||
pub user_agent: Option<String>,
|
||||
pub retry_config: Option<RetryConfig>,
|
||||
pub timeout_config: Option<TimeoutConfig>,
|
||||
pub extra_headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
@@ -104,6 +107,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
.unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()),
|
||||
retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
extra_headers: config.extra_headers.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ class ClientConfig:
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
extra_headers: Optional[dict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
|
||||
@@ -57,7 +57,7 @@ def mock_lancedb_connection(handler):
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mock_lancedb_connection_async(handler):
|
||||
async def mock_lancedb_connection_async(handler, **client_config):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
@@ -73,6 +73,7 @@ async def mock_lancedb_connection_async(handler):
|
||||
"timeout_config": {
|
||||
"connect_timeout": 1,
|
||||
},
|
||||
**client_config,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -522,3 +523,19 @@ def test_create_client():
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
lancedb.connect(**mandatory_args, request_thread_pool=10)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pass_through_headers():
|
||||
def handler(request):
|
||||
assert request.headers["foo"] == "bar"
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
async with mock_lancedb_connection_async(
|
||||
handler, extra_headers={"foo": "bar"}
|
||||
) as db:
|
||||
table_names = await db.table_names()
|
||||
assert table_names == []
|
||||
|
||||
@@ -223,6 +223,7 @@ pub struct PyClientConfig {
|
||||
user_agent: String,
|
||||
retry_config: Option<PyClientRetryConfig>,
|
||||
timeout_config: Option<PyClientTimeoutConfig>,
|
||||
extra_headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
@@ -274,6 +275,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
user_agent: value.user_agent,
|
||||
retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
extra_headers: value.extra_headers.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{future::Future, time::Duration};
|
||||
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
|
||||
|
||||
use http::HeaderName;
|
||||
use log::debug;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
@@ -23,6 +24,7 @@ pub struct ClientConfig {
|
||||
/// name and version.
|
||||
pub user_agent: String,
|
||||
// TODO: how to configure request ids?
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
@@ -31,6 +33,7 @@ impl Default for ClientConfig {
|
||||
timeout_config: TimeoutConfig::default(),
|
||||
retry_config: RetryConfig::default(),
|
||||
user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
|
||||
extra_headers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,6 +259,7 @@ impl RestfulLanceDbClient<Sender> {
|
||||
host_override.is_some(),
|
||||
options,
|
||||
db_prefix,
|
||||
&client_config,
|
||||
)?)
|
||||
.user_agent(client_config.user_agent)
|
||||
.build()
|
||||
@@ -291,6 +295,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
has_host_override: bool,
|
||||
options: &RemoteOptions,
|
||||
db_prefix: Option<&str>,
|
||||
config: &ClientConfig,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
@@ -345,6 +350,18 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
);
|
||||
}
|
||||
|
||||
for (key, value) in &config.extra_headers {
|
||||
let key_parsed = HeaderName::from_str(key).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii value for header '{}' provided", key),
|
||||
})?;
|
||||
headers.insert(
|
||||
key_parsed,
|
||||
HeaderValue::from_str(value).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii value for header '{}' provided", key),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user