mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat(rust): client configuration for remote client (#1696)
This PR ports over advanced client configuration present in the Python `RestfulLanceDBClient` to the Rust one. The goal is to have feature parity so we can replace the implementation. * [x] Request timeout * [x] Retries with backoff * [x] Request id generation * [x] User agent (with default tied to library version ✨ ) * [x] Table existence cache * [ ] Deferred: ~Request id customization (should this just pick up OTEL trace ids?)~ Fixes #1684
This commit is contained in:
@@ -45,10 +45,12 @@ half = { "version" = "=2.4.1", default-features = false, features = [
|
||||
] }
|
||||
futures = "0"
|
||||
log = "0.4"
|
||||
moka = { version = "0.11", features = ["future"] }
|
||||
object_store = "0.10.2"
|
||||
pin-project = "1.0.7"
|
||||
snafu = "0.7.4"
|
||||
url = "2"
|
||||
num-traits = "0.2"
|
||||
rand = "0.8"
|
||||
regex = "1.10"
|
||||
lazy_static = "1"
|
||||
|
||||
@@ -32,6 +32,7 @@ lance-table = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
lance-testing = { workspace = true }
|
||||
lance-encoding = { workspace = true }
|
||||
moka = { workspace = true}
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
log.workspace = true
|
||||
@@ -47,7 +48,9 @@ async-openai = { version = "0.20.0", optional = true }
|
||||
serde_with = { version = "3.8.1" }
|
||||
# For remote feature
|
||||
reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
rand = { version = "0.8.3", features = ["small_rng"], optional = true}
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
uuid = { version = "1.7.0", features = ["v4"], optional = true }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
hf-hub = { version = "0.3.2", optional = true }
|
||||
@@ -71,7 +74,7 @@ http-body = "1" # Matching reqwest
|
||||
|
||||
[features]
|
||||
default = []
|
||||
remote = ["dep:reqwest", "dep:http"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
openai = ["dep:async-openai", "dep:reqwest"]
|
||||
|
||||
@@ -32,6 +32,8 @@ use crate::embeddings::{
|
||||
};
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::remote::client::ClientConfig;
|
||||
use crate::table::{NativeTable, TableDefinition, WriteOptions};
|
||||
use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
@@ -567,6 +569,8 @@ pub struct ConnectBuilder {
|
||||
region: Option<String>,
|
||||
/// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
|
||||
host_override: Option<String>,
|
||||
#[cfg(feature = "remote")]
|
||||
client_config: ClientConfig,
|
||||
|
||||
storage_options: HashMap<String, String>,
|
||||
|
||||
@@ -592,6 +596,8 @@ impl ConnectBuilder {
|
||||
api_key: None,
|
||||
region: None,
|
||||
host_override: None,
|
||||
#[cfg(feature = "remote")]
|
||||
client_config: Default::default(),
|
||||
read_consistency_interval: None,
|
||||
storage_options: HashMap::new(),
|
||||
embedding_registry: None,
|
||||
@@ -613,6 +619,30 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Choose the client configuration used for LanceDB Cloud connections.
///
/// Any configuration previously stored on this builder is replaced by
/// `config`.
///
/// ```
/// # use lancedb::connect;
/// # use lancedb::remote::*;
/// connect("db://my_database")
///     .client_config(ClientConfig {
///         timeout_config: TimeoutConfig {
///             connect_timeout: Some(std::time::Duration::from_secs(5)),
///             ..Default::default()
///         },
///         retry_config: RetryConfig {
///             retries: Some(5),
///             ..Default::default()
///         },
///         ..Default::default()
///     });
/// ```
#[cfg(feature = "remote")]
pub fn client_config(self, config: ClientConfig) -> Self {
    // Rebuild the builder with only the client configuration swapped out.
    Self {
        client_config: config,
        ..self
    }
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -685,12 +715,14 @@ impl ConnectBuilder {
|
||||
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
// TODO: remove this warning when the remote client is ready
|
||||
warn!("The rust implementation of the remote client is not yet ready for use.");
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
&self.uri,
|
||||
&api_key,
|
||||
&region,
|
||||
self.host_override,
|
||||
self.client_config,
|
||||
)?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
|
||||
@@ -213,7 +213,7 @@ pub mod ipc;
|
||||
mod polars_arrow_convertors;
|
||||
pub mod query;
|
||||
#[cfg(feature = "remote")]
|
||||
pub(crate) mod remote;
|
||||
pub mod remote;
|
||||
pub mod table;
|
||||
pub mod utils;
|
||||
|
||||
|
||||
@@ -17,10 +17,12 @@
|
||||
//! building client/server applications with LanceDB or as a client for some
|
||||
//! other custom LanceDB service.
|
||||
|
||||
pub mod client;
|
||||
pub mod db;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
|
||||
|
||||
@@ -14,13 +14,152 @@
|
||||
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use log::debug;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
RequestBuilder, Response,
|
||||
Request, RequestBuilder, Response,
|
||||
};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
||||
/// Configuration for the LanceDB Cloud HTTP client.
|
||||
#[derive(Debug)]
|
||||
pub struct ClientConfig {
|
||||
pub timeout_config: TimeoutConfig,
|
||||
pub retry_config: RetryConfig,
|
||||
/// User agent to use for requests. The default provides the libary
|
||||
/// name and version.
|
||||
pub user_agent: String,
|
||||
// TODO: how to configure request ids?
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
timeout_config: TimeoutConfig::default(),
|
||||
retry_config: RetryConfig::default(),
|
||||
user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Timeout settings for HTTP requests.
///
/// A field left as `None` is resolved from the environment variable named in
/// its documentation, and otherwise from the documented default.
#[derive(Debug, Default)]
pub struct TimeoutConfig {
    /// How long to wait while establishing a connection to the server.
    ///
    /// Can also be provided through the `LANCE_CLIENT_CONNECT_TIMEOUT`
    /// environment variable, as an integer number of seconds.
    ///
    /// Defaults to 120 seconds (2 minutes).
    pub connect_timeout: Option<Duration>,
    /// How long to wait while reading a response from the server.
    ///
    /// Can also be provided through the `LANCE_CLIENT_READ_TIMEOUT`
    /// environment variable, as an integer number of seconds.
    ///
    /// Defaults to 300 seconds (5 minutes).
    pub read_timeout: Option<Duration>,
    /// How long idle connections are kept alive.
    ///
    /// Can also be provided through the `LANCE_CLIENT_CONNECTION_TIMEOUT`
    /// environment variable, as an integer number of seconds.
    ///
    /// Defaults to 300 seconds (5 minutes).
    pub pool_idle_timeout: Option<Duration>,
}
|
||||
|
||||
/// How to handle retries for HTTP requests.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct RetryConfig {
|
||||
/// The number of times to retry a request if it fails.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_MAX_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub retries: Option<u8>,
|
||||
/// The number of times to retry a request if it fails to connect.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_CONNECT_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub connect_retries: Option<u8>,
|
||||
/// The number of times to retry a request if it fails to read.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_READ_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub read_retries: Option<u8>,
|
||||
/// The exponential backoff factor to use when retrying requests.
|
||||
///
|
||||
/// Between each retry, the client will wait for the amount of seconds:
|
||||
///
|
||||
/// ```text
|
||||
/// {backoff factor} * (2 ** ({number of previous retries}))
|
||||
/// ```
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_BACKOFF_FACTOR` environment variable
|
||||
/// to set this value. Use a float value.
|
||||
///
|
||||
/// The default is 0.25. So the first retry will wait 0.25 seconds, the second
|
||||
/// retry will wait 0.5 seconds, the third retry will wait 1 second, etc.
|
||||
pub backoff_factor: Option<f32>,
|
||||
/// The backoff jitter factor to use when retrying requests.
|
||||
///
|
||||
/// The backoff jitter is a random value between 0 and the jitter factor in
|
||||
/// seconds.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_BACKOFF_JITTER` environment variable
|
||||
/// to set this value. Use a float value.
|
||||
///
|
||||
/// The default is 0.25. So between 0 and 0.25 seconds will be added to the
|
||||
/// sleep time between retries.
|
||||
pub backoff_jitter: Option<f32>,
|
||||
/// The set of status codes to retry on.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
|
||||
/// to set this value. Use a comma-separated list of integer values.
|
||||
///
|
||||
/// The default is 429, 500, 502, 503.
|
||||
pub statuses: Option<Vec<u16>>,
|
||||
// TODO: should we allow customizing methods?
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ResolvedRetryConfig {
|
||||
retries: u8,
|
||||
connect_retries: u8,
|
||||
read_retries: u8,
|
||||
backoff_factor: f32,
|
||||
backoff_jitter: f32,
|
||||
statuses: Vec<reqwest::StatusCode>,
|
||||
}
|
||||
|
||||
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(retry_config: RetryConfig) -> Result<Self> {
|
||||
Ok(Self {
|
||||
retries: retry_config.retries.unwrap_or(3),
|
||||
connect_retries: retry_config.connect_retries.unwrap_or(3),
|
||||
read_retries: retry_config.read_retries.unwrap_or(3),
|
||||
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
|
||||
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
|
||||
statuses: retry_config
|
||||
.statuses
|
||||
.unwrap_or_else(|| vec![429, 500, 502, 503])
|
||||
.into_iter()
|
||||
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
|
||||
// we can mock responses in tests. Based on the patterns from this blog post:
|
||||
// https://write.as/balrogboogie/testing-reqwest-based-clients
|
||||
@@ -28,28 +167,54 @@ use crate::error::{Error, Result};
|
||||
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
retry_config: ResolvedRetryConfig,
|
||||
sender: S,
|
||||
}
|
||||
|
||||
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
|
||||
fn send(&self, req: RequestBuilder) -> impl Future<Output = Result<Response>> + Send;
|
||||
fn send(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
request: reqwest::Request,
|
||||
) -> impl Future<Output = reqwest::Result<Response>> + Send;
|
||||
}
|
||||
|
||||
// Default implementation of HttpSend which sends the request normally with reqwest
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Sender;
|
||||
impl HttpSend for Sender {
|
||||
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
|
||||
Ok(request.send().await?)
|
||||
async fn send(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
request: reqwest::Request,
|
||||
) -> reqwest::Result<reqwest::Response> {
|
||||
client.execute(request).await
|
||||
}
|
||||
}
|
||||
|
||||
impl RestfulLanceDbClient<Sender> {
|
||||
fn get_timeout(passed: Option<Duration>, env_var: &str, default: Duration) -> Result<Duration> {
|
||||
if let Some(passed) = passed {
|
||||
Ok(passed)
|
||||
} else if let Ok(timeout) = std::env::var(env_var) {
|
||||
let timeout = timeout.parse::<u64>().map_err(|_| Error::InvalidInput {
|
||||
message: format!(
|
||||
"Invalid value for {} environment variable: '{}'",
|
||||
env_var, timeout
|
||||
),
|
||||
})?;
|
||||
Ok(Duration::from_secs(timeout))
|
||||
} else {
|
||||
Ok(default)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_new(
|
||||
db_url: &str,
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
client_config: ClientConfig,
|
||||
) -> Result<Self> {
|
||||
let parsed_url = url::Url::parse(db_url)?;
|
||||
debug_assert_eq!(parsed_url.scheme(), "db");
|
||||
@@ -59,22 +224,47 @@ impl RestfulLanceDbClient<Sender> {
|
||||
});
|
||||
}
|
||||
let db_name = parsed_url.host_str().unwrap();
|
||||
|
||||
// Get the timeouts
|
||||
let connect_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.connect_timeout,
|
||||
"LANCE_CLIENT_CONNECT_TIMEOUT",
|
||||
Duration::from_secs(120),
|
||||
)?;
|
||||
let read_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.read_timeout,
|
||||
"LANCE_CLIENT_READ_TIMEOUT",
|
||||
Duration::from_secs(300),
|
||||
)?;
|
||||
let pool_idle_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.pool_idle_timeout,
|
||||
// Though it's confusing with the connect_timeout name, this is the
|
||||
// legacy name for this in the Python sync client. So we keep as-is.
|
||||
"LANCE_CLIENT_CONNECTION_TIMEOUT",
|
||||
Duration::from_secs(300),
|
||||
)?;
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect_timeout(connect_timeout)
|
||||
.read_timeout(read_timeout)
|
||||
.pool_idle_timeout(pool_idle_timeout)
|
||||
.default_headers(Self::default_headers(
|
||||
api_key,
|
||||
region,
|
||||
db_name,
|
||||
host_override.is_some(),
|
||||
)?)
|
||||
.user_agent(client_config.user_agent)
|
||||
.build()?;
|
||||
let host = match host_override {
|
||||
Some(host_override) => host_override,
|
||||
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||
};
|
||||
let retry_config = client_config.retry_config.try_into()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host,
|
||||
retry_config,
|
||||
sender: Sender,
|
||||
})
|
||||
}
|
||||
@@ -129,8 +319,100 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
self.client.post(full_uri)
|
||||
}
|
||||
|
||||
pub async fn send(&self, req: RequestBuilder) -> Result<Response> {
|
||||
self.sender.send(req).await
|
||||
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<Response> {
|
||||
let (client, request) = req.build_split();
|
||||
let mut request = request.unwrap();
|
||||
|
||||
// Set a request id.
|
||||
// TODO: allow the user to supply this, through middleware?
|
||||
if request.headers().get(REQUEST_ID_HEADER).is_none() {
|
||||
let request_id = uuid::Uuid::new_v4();
|
||||
let request_id = HeaderValue::from_str(&request_id.to_string()).unwrap();
|
||||
request.headers_mut().insert(REQUEST_ID_HEADER, request_id);
|
||||
}
|
||||
|
||||
if with_retry {
|
||||
self.send_with_retry_impl(client, request).await
|
||||
} else {
|
||||
Ok(self.sender.send(&client, request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_with_retry_impl(
|
||||
&self,
|
||||
client: reqwest::Client,
|
||||
req: Request,
|
||||
) -> Result<Response> {
|
||||
let mut request_failures = 0;
|
||||
let mut connect_failures = 0;
|
||||
let mut read_failures = 0;
|
||||
|
||||
loop {
|
||||
// This only works if the request body is not a stream. If it is
|
||||
// a stream, we can't use the retry path. We would need to implement
|
||||
// an outer retry.
|
||||
let request = req.try_clone().ok_or_else(|| Error::Http {
|
||||
message: "Attempted to retry a request that cannot be cloned".to_string(),
|
||||
})?;
|
||||
let response = self.sender.send(&client, request).await;
|
||||
let status_code = response.as_ref().map(|r| r.status());
|
||||
match status_code {
|
||||
Ok(status) if status.is_success() => return Ok(response?),
|
||||
Ok(status) if self.retry_config.statuses.contains(&status) => {
|
||||
request_failures += 1;
|
||||
if request_failures >= self.retry_config.retries {
|
||||
// TODO: better error
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Request failed after {} retries with status code {}",
|
||||
request_failures, status
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(err) if err.is_connect() => {
|
||||
connect_failures += 1;
|
||||
if connect_failures >= self.retry_config.connect_retries {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Request failed after {} connect retries with error: {}",
|
||||
connect_failures, err
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(err) if err.is_timeout() || err.is_body() || err.is_decode() => {
|
||||
read_failures += 1;
|
||||
if read_failures >= self.retry_config.read_retries {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Request failed after {} read retries with error: {}",
|
||||
read_failures, err
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(_) | Err(_) => return Ok(response?),
|
||||
}
|
||||
|
||||
let backoff = self.retry_config.backoff_factor * (2.0f32.powi(request_failures as i32));
|
||||
let jitter = rand::random::<f32>() * self.retry_config.backoff_jitter;
|
||||
let sleep_time = Duration::from_secs_f32(backoff + jitter);
|
||||
debug!(
|
||||
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
|
||||
req.headers()
|
||||
.get("x-request-id")
|
||||
.and_then(|v| v.to_str().ok()),
|
||||
connect_failures,
|
||||
self.retry_config.connect_retries,
|
||||
request_failures,
|
||||
self.retry_config.retries,
|
||||
read_failures,
|
||||
self.retry_config.read_retries,
|
||||
sleep_time
|
||||
);
|
||||
tokio::time::sleep(sleep_time).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn rsp_to_str(response: Response) -> String {
|
||||
@@ -172,8 +454,11 @@ pub mod test_utils {
|
||||
}
|
||||
|
||||
impl HttpSend for MockSender {
|
||||
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
|
||||
let request = request.build().unwrap();
|
||||
async fn send(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
request: reqwest::Request,
|
||||
) -> reqwest::Result<reqwest::Response> {
|
||||
let response = (self.f)(request);
|
||||
Ok(response)
|
||||
}
|
||||
@@ -193,6 +478,7 @@ pub mod test_utils {
|
||||
RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "http://localhost".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
},
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use moka::future::Cache;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use serde::Deserialize;
|
||||
use tokio::task::spawn_blocking;
|
||||
@@ -28,7 +29,7 @@ use crate::embeddings::EmbeddingRegistry;
|
||||
use crate::error::Result;
|
||||
use crate::Table;
|
||||
|
||||
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||
use super::client::{ClientConfig, HttpSend, RestfulLanceDbClient, Sender};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::batches_to_ipc_bytes;
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
@@ -41,6 +42,7 @@ struct ListTablesResponse {
|
||||
#[derive(Debug)]
|
||||
pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
client: RestfulLanceDbClient<S>,
|
||||
table_cache: Cache<String, ()>,
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
@@ -49,9 +51,20 @@ impl RemoteDatabase {
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
client_config: ClientConfig,
|
||||
) -> Result<Self> {
|
||||
let client = RestfulLanceDbClient::try_new(uri, api_key, region, host_override)?;
|
||||
Ok(Self { client })
|
||||
let client =
|
||||
RestfulLanceDbClient::try_new(uri, api_key, region, host_override, client_config)?;
|
||||
|
||||
let table_cache = Cache::builder()
|
||||
.time_to_live(std::time::Duration::from_secs(300))
|
||||
.max_capacity(10_000)
|
||||
.build();
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
table_cache,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +81,10 @@ mod test_utils {
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler(handler);
|
||||
Self { client }
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -89,9 +105,13 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
if let Some(start_after) = options.start_after {
|
||||
req = req.query(&[("page_token", start_after)]);
|
||||
}
|
||||
let rsp = self.client.send(req).await?;
|
||||
let rsp = self.client.send(req, true).await?;
|
||||
let rsp = self.client.check_response(rsp).await?;
|
||||
Ok(rsp.json::<ListTablesResponse>().await?.tables)
|
||||
let tables = rsp.json::<ListTablesResponse>().await?.tables;
|
||||
for table in &tables {
|
||||
self.table_cache.insert(table.clone(), ()).await;
|
||||
}
|
||||
Ok(tables)
|
||||
}
|
||||
|
||||
async fn do_create_table(
|
||||
@@ -113,7 +133,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
// This is currently expected by LanceDb cloud but will be removed soon.
|
||||
.header("x-request-id", "na");
|
||||
let rsp = self.client.send(req).await?;
|
||||
let rsp = self.client.send(req, false).await?;
|
||||
|
||||
if rsp.status() == StatusCode::BAD_REQUEST {
|
||||
let body = rsp.text().await?;
|
||||
@@ -126,6 +146,8 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
|
||||
self.client.check_response(rsp).await?;
|
||||
|
||||
self.table_cache.insert(options.name.clone(), ()).await;
|
||||
|
||||
Ok(Table::new(Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
options.name,
|
||||
@@ -134,15 +156,17 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
|
||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
|
||||
// We describe the table to confirm it exists before moving on.
|
||||
// TODO: a TTL cache of table existence
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
let resp = self.client.send(req).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
if self.table_cache.get(&options.name).is_none() {
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
let resp = self.client.send(req, true).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
}
|
||||
self.client.check_response(resp).await?;
|
||||
}
|
||||
self.client.check_response(resp).await?;
|
||||
|
||||
Ok(Table::new(Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
options.name,
|
||||
@@ -154,15 +178,18 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/rename/", current_name));
|
||||
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
|
||||
let resp = self.client.send(req).await?;
|
||||
let resp = self.client.send(req, false).await?;
|
||||
self.client.check_response(resp).await?;
|
||||
self.table_cache.remove(current_name).await;
|
||||
self.table_cache.insert(new_name.into(), ()).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
|
||||
let resp = self.client.send(req).await?;
|
||||
let resp = self.client.send(req, true).await?;
|
||||
self.client.check_response(resp).await?;
|
||||
self.table_cache.remove(name).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
|
||||
async fn describe(&self) -> Result<TableDescription> {
|
||||
let request = self.client.post(&format!("/table/{}/describe/", self.name));
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, true).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
|
||||
@@ -257,7 +257,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
request = request.json(&serde_json::json!({}));
|
||||
}
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, true).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
|
||||
@@ -286,7 +286,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
}
|
||||
}
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(response).await?;
|
||||
|
||||
@@ -337,7 +337,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, true).await?;
|
||||
|
||||
let stream = self.read_arrow_stream(response).await?;
|
||||
|
||||
@@ -359,7 +359,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, true).await?;
|
||||
|
||||
let stream = self.read_arrow_stream(response).await?;
|
||||
|
||||
@@ -379,7 +379,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
"only_if": update.filter,
|
||||
}));
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, false).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
|
||||
@@ -398,7 +398,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
.client
|
||||
.post(&format!("/table/{}/delete/", self.name))
|
||||
.json(&body);
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, false).await?;
|
||||
self.check_table_response(response).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -468,7 +468,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(response).await?;
|
||||
|
||||
@@ -489,7 +489,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(body);
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(response).await?;
|
||||
|
||||
@@ -528,7 +528,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/index/{}/stats/", self.name, index_name));
|
||||
let response = self.client.send(request).await?;
|
||||
let response = self.client.send(request, true).await?;
|
||||
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Ok(None);
|
||||
|
||||
Reference in New Issue
Block a user