feat: better errors for remote SDK (#1722)

* Adds nicer errors to remote SDK, that expose useful properties like
`request_id` and `status_code`.
* Makes sure the Python tracebacks print nicely by mapping the `source`
field from a Rust error to the `__cause__` field.
This commit is contained in:
Will Jones
2024-10-08 21:21:13 -07:00
committed by GitHub
parent 607476788e
commit 8509f73221
9 changed files with 622 additions and 193 deletions

View File

@@ -103,19 +103,29 @@ class RestfulLanceDBClient:
@staticmethod
def _check_status(resp: requests.Response):
# Leaving request id empty for now, as we'll be replacing this impl
# with the Rust one shortly.
if resp.status_code == 404:
raise LanceDBClientError(f"Not found: {resp.text}")
raise LanceDBClientError(
f"Not found: {resp.text}", request_id="", status_code=404
)
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status_code}, error: {resp.text}"
f"Bad Request: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status_code}, error: {resp.text}"
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status_code}, error: {resp.text}"
f"Unknown Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
@_check_not_closed

View File

@@ -12,5 +12,102 @@
# limitations under the License.
from typing import Optional
class LanceDBClientError(RuntimeError):
"""An error that occurred in the LanceDB client.
Attributes
----------
message: str
The error message.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
status_code: int
The HTTP status code of the response. May be None if the request
failed before the response was received.
"""
def __init__(
self, message: str, request_id: str, status_code: Optional[int] = None
):
super().__init__(message)
self.request_id = request_id
self.status_code = status_code
class HttpError(LanceDBClientError):
"""An error that occurred during an HTTP request.
Attributes
----------
message: str
The error message.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
status_code: int
The HTTP status code of the response. May be None if the request
failed before the response was received.
"""
pass
class RetryError(LanceDBClientError):
"""An error that occurs when the client has exceeded the maximum number of retries.
The retry strategy can be adjusted by setting the
[retry_config](lancedb.remote.ClientConfig.retry_config) in the client
configuration. This is passed in the `client_config` argument of
[connect](lancedb.connect) and [connect_async](lancedb.connect_async).
The __cause__ attribute of this exception will be the last exception that
caused the retry to fail. It will be an
[HttpError][lancedb.remote.errors.HttpError] instance.
Attributes
----------
message: str
The retry error message, which will describe which retry limit was hit.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
request_failures: int
The number of request failures.
connect_failures: int
The number of connect failures.
read_failures: int
The number of read failures.
max_request_failures: int
The maximum number of request failures.
max_connect_failures: int
The maximum number of connect failures.
max_read_failures: int
The maximum number of read failures.
status_code: int
The HTTP status code of the last response. May be None if the request
failed before the response was received.
"""
def __init__(
self,
message: str,
request_id: str,
request_failures: int,
connect_failures: int,
read_failures: int,
max_request_failures: int,
max_connect_failures: int,
max_read_failures: int,
status_code: Optional[int],
):
super().__init__(message, request_id, status_code)
self.request_failures = request_failures
self.connect_failures = connect_failures
self.read_failures = read_failures
self.max_request_failures = max_request_failures
self.max_connect_failures = max_connect_failures
self.max_read_failures = max_read_failures

View File

@@ -354,7 +354,7 @@ async def test_create_mode_async(tmp_path):
)
await db.create_table("test", data=data)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
new_data = pd.DataFrame(
@@ -382,7 +382,7 @@ async def test_create_exist_ok_async(tmp_path):
)
tbl = await db.create_table("test", data=data)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
# open the table but don't add more rows

View File

@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib
import http.server
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
@@ -98,6 +100,33 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler
@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = await lancedb.connect_async(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@pytest.mark.asyncio
async def test_async_remote_db():
def handler(request):
@@ -114,28 +143,50 @@ async def test_async_remote_db():
request.end_headers()
request.wfile.write(b'{"tables": []}')
def run_server():
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
# we will only make one request
server.handle_request()
async with mock_lancedb_connection(handler) as db:
table_names = await db.table_names()
assert table_names == []
handle = threading.Thread(target=run_server)
handle.start()
db = await lancedb.connect_async(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
table_names = await db.table_names()
assert table_names == []
@pytest.mark.asyncio
async def test_http_error():
request_id_holder = {"request_id": None}
handle.join()
def handler(request):
request_id_holder["request_id"] = request.headers["x-request-id"]
request.send_response(507)
request.end_headers()
request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507
@pytest.mark.asyncio
async def test_retry_error():
request_id_holder = {"request_id": None}
def handler(request):
request_id_holder["request_id"] = request.headers["x-request-id"]
request.send_response(429)
request.end_headers()
request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__
assert isinstance(cause, HttpError)
assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429

View File

@@ -14,7 +14,9 @@
use pyo3::{
exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError},
PyResult,
intern,
types::{PyAnyMethods, PyNone},
PyErr, PyResult, Python,
};
use lancedb::error::Error as LanceError;
@@ -38,12 +40,79 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
LanceError::InvalidInput { .. }
| LanceError::InvalidTableName { .. }
| LanceError::TableNotFound { .. }
| LanceError::Schema { .. } => self.value_error(),
| LanceError::Schema { .. }
| LanceError::TableAlreadyExists { .. } => self.value_error(),
LanceError::CreateDir { .. } => self.os_error(),
LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
LanceError::NotSupported { .. } => {
Err(PyNotImplementedError::new_err(err.to_string()))
}
LanceError::Http {
request_id,
source,
status_code,
} => Python::with_gil(|py| {
let message = err.to_string();
let http_err_cls = py
.import_bound(intern!(py, "lancedb.remote.errors"))?
.getattr(intern!(py, "HttpError"))?;
let err = http_err_cls.call1((
message,
request_id,
status_code.map(|s| s.as_u16()),
))?;
if let Some(cause) = source.source() {
// The HTTP error already includes the first cause. But
// we can add the rest of the chain if there is any more.
let cause_err = http_from_rust_error(
py,
cause,
request_id,
status_code.map(|s| s.as_u16()),
)?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
}
Err(PyErr::from_value_bound(err))
}),
LanceError::Retry {
request_id,
request_failures,
max_request_failures,
connect_failures,
max_connect_failures,
read_failures,
max_read_failures,
source,
status_code,
} => Python::with_gil(|py| {
let cause_err = http_from_rust_error(
py,
source.as_ref(),
request_id,
status_code.map(|s| s.as_u16()),
)?;
let message = err.to_string();
let retry_error_cls = py
.import_bound(intern!(py, "lancedb.remote.errors"))?
.getattr("RetryError")?;
let err = retry_error_cls.call1((
message,
request_id,
*request_failures,
*connect_failures,
*read_failures,
*max_request_failures,
*max_connect_failures,
*max_read_failures,
status_code.map(|s| s.as_u16()),
))?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
Err(PyErr::from_value_bound(err))
}),
_ => self.runtime_error(),
},
}
@@ -61,3 +130,24 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
self.map_err(|err| PyValueError::new_err(err.to_string()))
}
}
fn http_from_rust_error(
py: Python<'_>,
err: &dyn std::error::Error,
request_id: &str,
status_code: Option<u16>,
) -> PyResult<PyErr> {
let message = err.to_string();
let http_err_cls = py.import("lancedb.remote.errors")?.getattr("HttpError")?;
let py_err = http_err_cls.call1((message, request_id, status_code))?;
// Reset the traceback since it doesn't provide additional information.
let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get_bound(py),))?;
if let Some(cause) = err.source() {
let cause_err = http_from_rust_error(py, cause, request_id, status_code)?;
py_err.setattr(intern!(py, "__cause__"), cause_err)?;
}
Ok(PyErr::from_value(py_err))
}

View File

@@ -46,8 +46,37 @@ pub enum Error {
ObjectStore { source: object_store::Error },
#[snafu(display("lance error: {source}"))]
Lance { source: lance::Error },
#[snafu(display("Http error: {message}"))]
Http { message: String },
#[cfg(feature = "remote")]
#[snafu(display("Http error: (request_id={request_id}) {source}"))]
Http {
#[snafu(source(from(reqwest::Error, Box::new)))]
source: Box<dyn std::error::Error + Send + Sync>,
request_id: String,
/// Status code associated with the error, if available.
/// This is not always available, for example when the error is due to a
/// connection failure. It may also be missing if the request was
/// successful but there was an error decoding the response.
status_code: Option<reqwest::StatusCode>,
},
#[cfg(feature = "remote")]
#[snafu(display(
"Hit retry limit for request_id={request_id} (\
request_failures={request_failures}/{max_request_failures}, \
connect_failures={connect_failures}/{max_connect_failures}, \
read_failures={read_failures}/{max_read_failures})"
))]
Retry {
request_id: String,
request_failures: u8,
max_request_failures: u8,
connect_failures: u8,
max_connect_failures: u8,
read_failures: u8,
max_read_failures: u8,
#[snafu(source(from(reqwest::Error, Box::new)))]
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
},
#[snafu(display("Arrow error: {source}"))]
Arrow { source: ArrowError },
#[snafu(display("LanceDBError: not supported: {message}"))]
@@ -98,24 +127,6 @@ impl<T> From<PoisonError<T>> for Error {
}
}
#[cfg(feature = "remote")]
impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Self {
Self::Http {
message: e.to_string(),
}
}
}
#[cfg(feature = "remote")]
impl From<url::ParseError> for Error {
fn from(e: url::ParseError) -> Self {
Self::Http {
message: e.to_string(),
}
}
}
#[cfg(feature = "polars")]
impl From<polars::prelude::PolarsError> for Error {
fn from(source: polars::prelude::PolarsError) -> Self {

View File

@@ -216,10 +216,12 @@ impl RestfulLanceDbClient<Sender> {
host_override: Option<String>,
client_config: ClientConfig,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput {
message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"),
})?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
return Err(Error::InvalidInput {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
@@ -255,7 +257,11 @@ impl RestfulLanceDbClient<Sender> {
host_override.is_some(),
)?)
.user_agent(client_config.user_agent)
.build()?;
.build()
.map_err(|err| Error::Other {
message: "Failed to build HTTP client".into(),
source: Some(Box::new(err)),
})?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
@@ -284,7 +290,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
let mut headers = HeaderMap::new();
headers.insert(
"x-api-key",
HeaderValue::from_str(api_key).map_err(|_| Error::Http {
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
);
@@ -292,7 +298,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
"Host",
HeaderValue::from_str(&host).map_err(|_| Error::Http {
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
);
@@ -300,7 +306,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if has_host_override {
headers.insert(
"x-lancedb-database",
HeaderValue::from_str(db_name).map_err(|_| Error::Http {
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
);
@@ -319,22 +325,30 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
self.client.post(full_uri)
}
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<Response> {
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
let (client, request) = req.build_split();
let mut request = request.unwrap();
// Set a request id.
// TODO: allow the user to supply this, through middleware?
if request.headers().get(REQUEST_ID_HEADER).is_none() {
let request_id = uuid::Uuid::new_v4();
let request_id = HeaderValue::from_str(&request_id.to_string()).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, request_id);
}
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
let header = HeaderValue::from_str(&request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
request_id
};
if with_retry {
self.send_with_retry_impl(client, request).await
self.send_with_retry_impl(client, request, request_id).await
} else {
Ok(self.sender.send(&client, request).await?)
let response = self
.sender
.send(&client, request)
.await
.err_to_http(request_id.clone())?;
Ok((request_id, response))
}
}
@@ -342,98 +356,178 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
&self,
client: reqwest::Client,
req: Request,
) -> Result<Response> {
let mut request_failures = 0;
let mut connect_failures = 0;
let mut read_failures = 0;
request_id: String,
) -> Result<(String, Response)> {
let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);
loop {
// This only works if the request body is not a stream. If it is
// a stream, we can't use the retry path. We would need to implement
// an outer retry.
let request = req.try_clone().ok_or_else(|| Error::Http {
let request = req.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let response = self.sender.send(&client, request).await;
let status_code = response.as_ref().map(|r| r.status());
match status_code {
Ok(status) if status.is_success() => return Ok(response?),
Ok(status) if self.retry_config.statuses.contains(&status) => {
request_failures += 1;
if request_failures >= self.retry_config.retries {
// TODO: better error
return Err(Error::Runtime {
message: format!(
"Request failed after {} retries with status code {}",
request_failures, status
),
});
}
let response = self
.sender
.send(&client, request)
.await
.map(|r| (r.status(), r));
match response {
Ok((status, response)) if status.is_success() => {
return Ok((retry_counter.request_id, response))
}
Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
let source = self
.check_response(&retry_counter.request_id, response)
.await
.unwrap_err();
retry_counter.increment_request_failures(source)?;
}
Err(err) if err.is_connect() => {
connect_failures += 1;
if connect_failures >= self.retry_config.connect_retries {
return Err(Error::Runtime {
message: format!(
"Request failed after {} connect retries with error: {}",
connect_failures, err
),
});
}
retry_counter.increment_connect_failures(err)?;
}
Err(err) if err.is_timeout() || err.is_body() || err.is_decode() => {
read_failures += 1;
if read_failures >= self.retry_config.read_retries {
return Err(Error::Runtime {
message: format!(
"Request failed after {} read retries with error: {}",
read_failures, err
),
});
}
retry_counter.increment_read_failures(err)?;
}
Ok(_) | Err(_) => return Ok(response?),
Err(err) => {
let status_code = err.status();
return Err(Error::Http {
source: Box::new(err),
request_id: retry_counter.request_id,
status_code,
});
}
Ok((_, response)) => return Ok((retry_counter.request_id, response)),
}
let backoff = self.retry_config.backoff_factor * (2.0f32.powi(request_failures as i32));
let jitter = rand::random::<f32>() * self.retry_config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
req.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok()),
connect_failures,
self.retry_config.connect_retries,
request_failures,
self.retry_config.retries,
read_failures,
self.retry_config.read_retries,
sleep_time
);
let sleep_time = retry_counter.next_sleep_time();
tokio::time::sleep(sleep_time).await;
}
}
async fn rsp_to_str(response: Response) -> String {
pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
// Try to get the response text, but if that fails, just return the status code
let status = response.status();
response.text().await.unwrap_or_else(|_| status.to_string())
if status.is_success() {
Ok(response)
} else {
let response_text = response.text().await.ok();
let message = if let Some(response_text) = response_text {
format!("{}: {}", status, response_text)
} else {
status.to_string()
};
Err(Error::Http {
source: message.into(),
request_id: request_id.into(),
status_code: Some(status),
})
}
}
}
struct RetryCounter<'a> {
request_failures: u8,
connect_failures: u8,
read_failures: u8,
config: &'a ResolvedRetryConfig,
request_id: String,
}
impl<'a> RetryCounter<'a> {
fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
pub async fn check_response(&self, response: Response) -> Result<Response> {
let status_int: u16 = u16::from(response.status());
if (400..500).contains(&status_int) {
Err(Error::InvalidInput {
message: Self::rsp_to_str(response).await,
})
} else if status_int != 200 {
Err(Error::Runtime {
message: Self::rsp_to_str(response).await,
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(response)
Ok(())
}
}
fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
pub trait RequestResultExt {
type Output;
fn err_to_http(self, request_id: String) -> Result<Self::Output>;
}
impl<T> RequestResultExt for reqwest::Result<T> {
type Output = T;
fn err_to_http(self, request_id: String) -> Result<T> {
self.map_err(|err| {
let status_code = err.status();
Error::Http {
source: Box::new(err),
request_id,
status_code,
}
})
}
}
#[cfg(test)]

View File

@@ -29,7 +29,7 @@ use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
use super::client::{ClientConfig, HttpSend, RestfulLanceDbClient, Sender};
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::batches_to_ipc_bytes;
use super::ARROW_STREAM_CONTENT_TYPE;
@@ -105,9 +105,13 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
if let Some(start_after) = options.start_after {
req = req.query(&[("page_token", start_after)]);
}
let rsp = self.client.send(req, true).await?;
let rsp = self.client.check_response(rsp).await?;
let tables = rsp.json::<ListTablesResponse>().await?.tables;
let (request_id, rsp) = self.client.send(req, true).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let tables = rsp
.json::<ListTablesResponse>()
.await
.err_to_http(request_id)?
.tables;
for table in &tables {
self.table_cache.insert(table.clone(), ()).await;
}
@@ -130,13 +134,11 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
.client
.post(&format!("/v1/table/{}/create/", options.name))
.body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
// This is currently expected by LanceDb cloud but will be removed soon.
.header("x-request-id", "na");
let rsp = self.client.send(req, false).await?;
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, rsp) = self.client.send(req, false).await?;
if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await?;
let body = rsp.text().await.err_to_http(request_id.clone())?;
if body.contains("already exists") {
return Err(crate::Error::TableAlreadyExists { name: options.name });
} else {
@@ -144,7 +146,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
}
}
self.client.check_response(rsp).await?;
self.client.check_response(&request_id, rsp).await?;
self.table_cache.insert(options.name.clone(), ()).await;
@@ -160,11 +162,11 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
let req = self
.client
.get(&format!("/v1/table/{}/describe/", options.name));
let resp = self.client.send(req, true).await?;
let (request_id, resp) = self.client.send(req, true).await?;
if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name });
}
self.client.check_response(resp).await?;
self.client.check_response(&request_id, resp).await?;
}
Ok(Table::new(Arc::new(RemoteTable::new(
@@ -178,8 +180,8 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
.client
.post(&format!("/v1/table/{}/rename/", current_name));
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
let resp = self.client.send(req, false).await?;
self.client.check_response(resp).await?;
let (request_id, resp) = self.client.send(req, false).await?;
self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(current_name).await;
self.table_cache.insert(new_name.into(), ()).await;
Ok(())
@@ -187,8 +189,8 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let resp = self.client.send(req, true).await?;
self.client.check_response(resp).await?;
let (request_id, resp) = self.client.send(req, true).await?;
self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(name).await;
Ok(())
}
@@ -206,16 +208,57 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use crate::{
remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
Connection,
Connection, Error,
};
#[tokio::test]
async fn test_retries() {
// We'll record the request_id here, to check it matches the one in the error.
let seen_request_id = Arc::new(OnceLock::new());
let seen_request_id_ref = seen_request_id.clone();
let conn = Connection::new_with_handler(move |request| {
// Request id should be the same on each retry.
let request_id = request.headers()["x-request-id"]
.to_str()
.unwrap()
.to_string();
let seen_id = seen_request_id_ref.get_or_init(|| request_id.clone());
assert_eq!(&request_id, seen_id);
http::Response::builder()
.status(500)
.body("internal server error")
.unwrap()
});
let result = conn.table_names().execute().await;
if let Err(Error::Retry {
request_id,
request_failures,
max_request_failures,
source,
..
}) = result
{
let expected_id = seen_request_id.get().unwrap();
assert_eq!(&request_id, expected_id);
assert_eq!(request_failures, max_request_failures);
assert!(
source.to_string().contains("internal server error"),
"source: {:?}",
source
);
} else {
panic!("unexpected result: {:?}", result);
};
}
#[tokio::test]
async fn test_table_names() {
let conn = Connection::new_with_handler(|request| {

View File

@@ -34,6 +34,7 @@ use crate::{
},
};
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
@@ -53,15 +54,25 @@ impl<S: HttpSend> RemoteTable<S> {
let request = self
.client
.post(&format!("/v1/table/{}/describe/", self.name));
let response = self.client.send(request, true).await?;
let (request_id, response) = self.client.send(request, true).await?;
let response = self.check_table_response(response).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await?;
serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!("Failed to parse table description: {}", e),
})
match response.text().await {
Ok(body) => serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!("Failed to parse table description: {}", e).into(),
request_id,
status_code: None,
}),
Err(err) => {
let status_code = err.status();
Err(Error::Http {
source: Box::new(err),
request_id,
status_code,
})
}
}
}
fn reader_as_body(data: Box<dyn RecordBatchReader + Send>) -> Result<reqwest::Body> {
@@ -87,18 +98,23 @@ impl<S: HttpSend> RemoteTable<S> {
Ok(reqwest::Body::wrap_stream(body_stream))
}
async fn check_table_response(&self, response: reqwest::Response) -> Result<reqwest::Response> {
async fn check_table_response(
&self,
request_id: &str,
response: reqwest::Response,
) -> Result<reqwest::Response> {
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
self.client.check_response(response).await
self.client.check_response(request_id, response).await
}
async fn read_arrow_stream(
&self,
request_id: &str,
body: reqwest::Response,
) -> Result<SendableRecordBatchStream> {
// Assert that the content type is correct
@@ -106,24 +122,31 @@ impl<S: HttpSend> RemoteTable<S> {
.headers()
.get(CONTENT_TYPE)
.ok_or_else(|| Error::Http {
message: "Missing content type".into(),
source: "Missing content type".into(),
request_id: request_id.to_string(),
status_code: None,
})?
.to_str()
.map_err(|e| Error::Http {
message: format!("Failed to parse content type: {}", e),
source: format!("Failed to parse content type: {}", e).into(),
request_id: request_id.to_string(),
status_code: None,
})?;
if content_type != ARROW_STREAM_CONTENT_TYPE {
return Err(Error::Http {
message: format!(
source: format!(
"Expected content type {}, got {}",
ARROW_STREAM_CONTENT_TYPE, content_type
),
)
.into(),
request_id: request_id.to_string(),
status_code: None,
});
}
// There isn't a way to actually stream this data yet. I have an upstream issue:
// https://github.com/apache/arrow-rs/issues/6420
let body = body.bytes().await?;
let body = body.bytes().await.err_to_http(request_id.into())?;
let reader = StreamReader::try_new(body.reader(), None)?;
let schema = reader.schema();
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
@@ -259,14 +282,16 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
request = request.json(&serde_json::json!({}));
}
let response = self.client.send(request, true).await?;
let (request_id, response) = self.client.send(request, true).await?;
let response = self.check_table_response(response).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await?;
let body = response.text().await.err_to_http(request_id.clone())?;
serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!("Failed to parse row count: {}", e),
source: format!("Failed to parse row count: {}", e).into(),
request_id,
status_code: None,
})
}
async fn add(
@@ -288,9 +313,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
}
}
let response = self.client.send(request, false).await?;
let (request_id, response) = self.client.send(request, false).await?;
self.check_table_response(response).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -339,9 +364,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let request = request.json(&body);
let response = self.client.send(request, true).await?;
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(response).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(Arc::new(OneShotExec::new(stream)))
}
@@ -361,9 +386,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let request = request.json(&body);
let response = self.client.send(request, true).await?;
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(response).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(DatasetRecordBatchStream::new(stream))
}
@@ -383,17 +408,20 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
"only_if": update.filter,
}));
let response = self.client.send(request, false).await?;
let (request_id, response) = self.client.send(request, false).await?;
let response = self.check_table_response(response).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await?;
let body = response.text().await.err_to_http(request_id.clone())?;
serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!(
source: format!(
"Failed to parse updated rows result from response {}: {}",
body, e
),
)
.into(),
request_id,
status_code: None,
})
}
async fn delete(&self, predicate: &str) -> Result<()> {
@@ -402,8 +430,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/delete/", self.name))
.json(&body);
let response = self.client.send(request, false).await?;
self.check_table_response(response).await?;
let (request_id, response) = self.client.send(request, false).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -474,9 +502,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let request = request.json(&body);
let response = self.client.send(request, false).await?;
let (request_id, response) = self.client.send(request, false).await?;
self.check_table_response(response).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -495,9 +523,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
let response = self.client.send(request, false).await?;
let (request_id, response) = self.client.send(request, false).await?;
self.check_table_response(response).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -531,8 +559,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let request = self
.client
.post(&format!("/v1/table/{}/index/list/", self.name));
let response = self.client.send(request, true).await?;
let response = self.check_table_response(response).await?;
let (request_id, response) = self.client.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(Deserialize)]
struct ListIndicesResponse {
@@ -545,12 +573,15 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
columns: Vec<String>,
}
let body = response.text().await?;
let body = response.text().await.err_to_http(request_id.clone())?;
let body: ListIndicesResponse = serde_json::from_str(&body).map_err(|err| Error::Http {
message: format!(
source: format!(
"Failed to parse list_indices response: {}, body: {}",
err, body
),
)
.into(),
request_id,
status_code: None,
})?;
// Make request to get stats for each index, so we get the index type.
@@ -581,18 +612,20 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
"/v1/table/{}/index/{}/stats/",
self.name, index_name
));
let response = self.client.send(request, true).await?;
let (request_id, response) = self.client.send(request, true).await?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
let response = self.check_table_response(response).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await?;
let body = response.text().await.err_to_http(request_id.clone())?;
let stats = serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!("Failed to parse index statistics: {}", e),
source: format!("Failed to parse index statistics: {}", e).into(),
request_id,
status_code: None,
})?;
Ok(Some(stats))