From 8509f732215a01456ba704fe8722dd88f8550135 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 8 Oct 2024 21:21:13 -0700 Subject: [PATCH] feat: better errors for remote SDK (#1722) * Adds nicer errors to remote SDK, that expose useful properties like `request_id` and `status_code`. * Makes sure the Python tracebacks print nicely by mapping the `source` field from a Rust error to the `__cause__` field. --- python/python/lancedb/remote/client.py | 18 +- python/python/lancedb/remote/errors.py | 97 ++++++++++ python/python/tests/test_db.py | 4 +- python/python/tests/test_remote_db.py | 95 ++++++--- python/src/error.rs | 94 ++++++++- rust/lancedb/src/error.rs | 51 +++-- rust/lancedb/src/remote/client.rs | 256 +++++++++++++++++-------- rust/lancedb/src/remote/db.rs | 79 ++++++-- rust/lancedb/src/remote/table.rs | 121 +++++++----- 9 files changed, 622 insertions(+), 193 deletions(-) diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 5ad9a2d0..d546e92f 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -103,19 +103,29 @@ class RestfulLanceDBClient: @staticmethod def _check_status(resp: requests.Response): + # Leaving request id empty for now, as we'll be replacing this impl + # with the Rust one shortly. 
if resp.status_code == 404: - raise LanceDBClientError(f"Not found: {resp.text}") + raise LanceDBClientError( + f"Not found: {resp.text}", request_id="", status_code=404 + ) elif 400 <= resp.status_code < 500: raise LanceDBClientError( - f"Bad Request: {resp.status_code}, error: {resp.text}" + f"Bad Request: {resp.status_code}, error: {resp.text}", + request_id="", + status_code=resp.status_code, ) elif 500 <= resp.status_code < 600: raise LanceDBClientError( - f"Internal Server Error: {resp.status_code}, error: {resp.text}" + f"Internal Server Error: {resp.status_code}, error: {resp.text}", + request_id="", + status_code=resp.status_code, ) elif resp.status_code != 200: raise LanceDBClientError( - f"Unknown Error: {resp.status_code}, error: {resp.text}" + f"Unknown Error: {resp.status_code}, error: {resp.text}", + request_id="", + status_code=resp.status_code, ) @_check_not_closed diff --git a/python/python/lancedb/remote/errors.py b/python/python/lancedb/remote/errors.py index a4d290dc..d8f3fde6 100644 --- a/python/python/lancedb/remote/errors.py +++ b/python/python/lancedb/remote/errors.py @@ -12,5 +12,102 @@ # limitations under the License. +from typing import Optional + + class LanceDBClientError(RuntimeError): + """An error that occurred in the LanceDB client. + + Attributes + ---------- + message: str + The error message. + request_id: str + The id of the request that failed. This can be provided in error reports + to help diagnose the issue. + status_code: Optional[int] + The HTTP status code of the response. May be None if the request + failed before the response was received. + """ + + def __init__( + self, message: str, request_id: str, status_code: Optional[int] = None + ): + super().__init__(message) + self.request_id = request_id + self.status_code = status_code + + +class HttpError(LanceDBClientError): + """An error that occurred during an HTTP request. + + Attributes + ---------- + message: str + The error message. 
+ request_id: str + The id of the request that failed. This can be provided in error reports + to help diagnose the issue. + status_code: Optional[int] + The HTTP status code of the response. May be None if the request + failed before the response was received. + """ + pass + + +class RetryError(LanceDBClientError): + """An error that occurs when the client has exceeded the maximum number of retries. + + The retry strategy can be adjusted by setting the + [retry_config][lancedb.remote.ClientConfig.retry_config] in the client + configuration. This is passed in the `client_config` argument of + [connect][lancedb.connect] and [connect_async][lancedb.connect_async]. + + The `__cause__` attribute of this exception will be the last exception that + caused the retry to fail. It will be an + [HttpError][lancedb.remote.errors.HttpError] instance. + + Attributes + ---------- + message: str + The retry error message, which will describe which retry limit was hit. + request_id: str + The id of the request that failed. This can be provided in error reports + to help diagnose the issue. + request_failures: int + The number of request failures. + connect_failures: int + The number of connect failures. + read_failures: int + The number of read failures. + max_request_failures: int + The maximum number of request failures. + max_connect_failures: int + The maximum number of connect failures. + max_read_failures: int + The maximum number of read failures. + status_code: Optional[int] + The HTTP status code of the last response. May be None if the request + failed before the response was received. 
+ """ + + def __init__( + self, + message: str, + request_id: str, + request_failures: int, + connect_failures: int, + read_failures: int, + max_request_failures: int, + max_connect_failures: int, + max_read_failures: int, + status_code: Optional[int], + ): + super().__init__(message, request_id, status_code) + self.request_failures = request_failures + self.connect_failures = connect_failures + self.read_failures = read_failures + self.max_request_failures = max_request_failures + self.max_connect_failures = max_connect_failures + self.max_read_failures = max_read_failures diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 8bd7d3af..2e01343b 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -354,7 +354,7 @@ async def test_create_mode_async(tmp_path): ) await db.create_table("test", data=data) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError, match="already exists"): await db.create_table("test", data=data) new_data = pd.DataFrame( @@ -382,7 +382,7 @@ async def test_create_exist_ok_async(tmp_path): ) tbl = await db.create_table("test", data=data) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError, match="already exists"): await db.create_table("test", data=data) # open the table but don't add more rows diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index dee183d9..e03b6636 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +import contextlib import http.server import threading from unittest.mock import MagicMock import uuid import lancedb +from lancedb.remote.errors import HttpError, RetryError import pyarrow as pa from lancedb.remote.client import VectorQuery, VectorQueryResult import pytest @@ -98,6 +100,33 @@ def make_mock_http_handler(handler): return 
MockLanceDBHandler +@contextlib.asynccontextmanager +async def mock_lancedb_connection(handler): + with http.server.HTTPServer( + ("localhost", 8080), make_mock_http_handler(handler) + ) as server: + handle = threading.Thread(target=server.serve_forever) + handle.start() + + db = await lancedb.connect_async( + "db://dev", + api_key="fake", + host_override="http://localhost:8080", + client_config={ + "retry_config": {"retries": 2}, + "timeout_config": { + "connect_timeout": 1, + }, + }, + ) + + try: + yield db + finally: + server.shutdown() + handle.join() + + @pytest.mark.asyncio async def test_async_remote_db(): def handler(request): @@ -114,28 +143,50 @@ async def test_async_remote_db(): request.end_headers() request.wfile.write(b'{"tables": []}') - def run_server(): - with http.server.HTTPServer( - ("localhost", 8080), make_mock_http_handler(handler) - ) as server: - # we will only make one request - server.handle_request() + async with mock_lancedb_connection(handler) as db: + table_names = await db.table_names() + assert table_names == [] - handle = threading.Thread(target=run_server) - handle.start() - db = await lancedb.connect_async( - "db://dev", - api_key="fake", - host_override="http://localhost:8080", - client_config={ - "retry_config": {"retries": 2}, - "timeout_config": { - "connect_timeout": 1, - }, - }, - ) - table_names = await db.table_names() - assert table_names == [] +@pytest.mark.asyncio +async def test_http_error(): + request_id_holder = {"request_id": None} - handle.join() + def handler(request): + request_id_holder["request_id"] = request.headers["x-request-id"] + + request.send_response(507) + request.end_headers() + request.wfile.write(b"Internal Server Error") + + async with mock_lancedb_connection(handler) as db: + with pytest.raises(HttpError, match="Internal Server Error") as exc_info: + await db.table_names() + + assert exc_info.value.request_id == request_id_holder["request_id"] + assert exc_info.value.status_code == 507 + + 
+@pytest.mark.asyncio +async def test_retry_error(): + request_id_holder = {"request_id": None} + + def handler(request): + request_id_holder["request_id"] = request.headers["x-request-id"] + + request.send_response(429) + request.end_headers() + request.wfile.write(b"Try again later") + + async with mock_lancedb_connection(handler) as db: + with pytest.raises(RetryError, match="Hit retry limit") as exc_info: + await db.table_names() + + assert exc_info.value.request_id == request_id_holder["request_id"] + assert exc_info.value.status_code == 429 + + cause = exc_info.value.__cause__ + assert isinstance(cause, HttpError) + assert "Try again later" in str(cause) + assert cause.request_id == request_id_holder["request_id"] + assert cause.status_code == 429 diff --git a/python/src/error.rs b/python/src/error.rs index 4688b523..4855b8f5 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -14,7 +14,9 @@ use pyo3::{ exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError}, - PyResult, + intern, + types::{PyAnyMethods, PyNone}, + PyErr, PyResult, Python, }; use lancedb::error::Error as LanceError; @@ -38,12 +40,79 @@ impl PythonErrorExt for std::result::Result { LanceError::InvalidInput { .. } | LanceError::InvalidTableName { .. } | LanceError::TableNotFound { .. } - | LanceError::Schema { .. } => self.value_error(), + | LanceError::Schema { .. } + | LanceError::TableAlreadyExists { .. } => self.value_error(), LanceError::CreateDir { .. } => self.os_error(), LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())), LanceError::NotSupported { .. } => { Err(PyNotImplementedError::new_err(err.to_string())) } + LanceError::Http { + request_id, + source, + status_code, + } => Python::with_gil(|py| { + let message = err.to_string(); + let http_err_cls = py + .import_bound(intern!(py, "lancedb.remote.errors"))? 
+ .getattr(intern!(py, "HttpError"))?; + let err = http_err_cls.call1(( + message, + request_id, + status_code.map(|s| s.as_u16()), + ))?; + + if let Some(cause) = source.source() { + // The HTTP error already includes the first cause. But + // we can add the rest of the chain if there is any more. + let cause_err = http_from_rust_error( + py, + cause, + request_id, + status_code.map(|s| s.as_u16()), + )?; + err.setattr(intern!(py, "__cause__"), cause_err)?; + } + + Err(PyErr::from_value_bound(err)) + }), + LanceError::Retry { + request_id, + request_failures, + max_request_failures, + connect_failures, + max_connect_failures, + read_failures, + max_read_failures, + source, + status_code, + } => Python::with_gil(|py| { + let cause_err = http_from_rust_error( + py, + source.as_ref(), + request_id, + status_code.map(|s| s.as_u16()), + )?; + + let message = err.to_string(); + let retry_error_cls = py + .import_bound(intern!(py, "lancedb.remote.errors"))? + .getattr("RetryError")?; + let err = retry_error_cls.call1(( + message, + request_id, + *request_failures, + *connect_failures, + *read_failures, + *max_request_failures, + *max_connect_failures, + *max_read_failures, + status_code.map(|s| s.as_u16()), + ))?; + + err.setattr(intern!(py, "__cause__"), cause_err)?; + Err(PyErr::from_value_bound(err)) + }), _ => self.runtime_error(), }, } @@ -61,3 +130,24 @@ impl PythonErrorExt for std::result::Result { self.map_err(|err| PyValueError::new_err(err.to_string())) } } + +fn http_from_rust_error( + py: Python<'_>, + err: &dyn std::error::Error, + request_id: &str, + status_code: Option, +) -> PyResult { + let message = err.to_string(); + let http_err_cls = py.import("lancedb.remote.errors")?.getattr("HttpError")?; + let py_err = http_err_cls.call1((message, request_id, status_code))?; + + // Reset the traceback since it doesn't provide additional information. 
+ let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get_bound(py),))?; + + if let Some(cause) = err.source() { + let cause_err = http_from_rust_error(py, cause, request_id, status_code)?; + py_err.setattr(intern!(py, "__cause__"), cause_err)?; + } + + Ok(PyErr::from_value(py_err)) +} diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 77f2373a..37bd8852 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -46,8 +46,37 @@ pub enum Error { ObjectStore { source: object_store::Error }, #[snafu(display("lance error: {source}"))] Lance { source: lance::Error }, - #[snafu(display("Http error: {message}"))] - Http { message: String }, + #[cfg(feature = "remote")] + #[snafu(display("Http error: (request_id={request_id}) {source}"))] + Http { + #[snafu(source(from(reqwest::Error, Box::new)))] + source: Box, + request_id: String, + /// Status code associated with the error, if available. + /// This is not always available, for example when the error is due to a + /// connection failure. It may also be missing if the request was + /// successful but there was an error decoding the response. 
+ status_code: Option, + }, + #[cfg(feature = "remote")] + #[snafu(display( + "Hit retry limit for request_id={request_id} (\ + request_failures={request_failures}/{max_request_failures}, \ + connect_failures={connect_failures}/{max_connect_failures}, \ + read_failures={read_failures}/{max_read_failures})" + ))] + Retry { + request_id: String, + request_failures: u8, + max_request_failures: u8, + connect_failures: u8, + max_connect_failures: u8, + read_failures: u8, + max_read_failures: u8, + #[snafu(source(from(reqwest::Error, Box::new)))] + source: Box, + status_code: Option, + }, #[snafu(display("Arrow error: {source}"))] Arrow { source: ArrowError }, #[snafu(display("LanceDBError: not supported: {message}"))] @@ -98,24 +127,6 @@ impl From> for Error { } } -#[cfg(feature = "remote")] -impl From for Error { - fn from(e: reqwest::Error) -> Self { - Self::Http { - message: e.to_string(), - } - } -} - -#[cfg(feature = "remote")] -impl From for Error { - fn from(e: url::ParseError) -> Self { - Self::Http { - message: e.to_string(), - } - } -} - #[cfg(feature = "polars")] impl From for Error { fn from(source: polars::prelude::PolarsError) -> Self { diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index a42c0733..83d5a14f 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -216,10 +216,12 @@ impl RestfulLanceDbClient { host_override: Option, client_config: ClientConfig, ) -> Result { - let parsed_url = url::Url::parse(db_url)?; + let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput { + message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"), + })?; debug_assert_eq!(parsed_url.scheme(), "db"); if !parsed_url.has_host() { - return Err(Error::Http { + return Err(Error::InvalidInput { message: format!("Invalid database URL (missing host) '{}'", db_url), }); } @@ -255,7 +257,11 @@ impl RestfulLanceDbClient { host_override.is_some(), )?) 
.user_agent(client_config.user_agent) - .build()?; + .build() + .map_err(|err| Error::Other { + message: "Failed to build HTTP client".into(), + source: Some(Box::new(err)), + })?; let host = match host_override { Some(host_override) => host_override, None => format!("https://{}.{}.api.lancedb.com", db_name, region), @@ -284,7 +290,7 @@ impl RestfulLanceDbClient { let mut headers = HeaderMap::new(); headers.insert( "x-api-key", - HeaderValue::from_str(api_key).map_err(|_| Error::Http { + HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput { message: "non-ascii api key provided".to_string(), })?, ); @@ -292,7 +298,7 @@ impl RestfulLanceDbClient { let host = format!("{}.local.api.lancedb.com", db_name); headers.insert( "Host", - HeaderValue::from_str(&host).map_err(|_| Error::Http { + HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput { message: format!("non-ascii database name '{}' provided", db_name), })?, ); @@ -300,7 +306,7 @@ impl RestfulLanceDbClient { if has_host_override { headers.insert( "x-lancedb-database", - HeaderValue::from_str(db_name).map_err(|_| Error::Http { + HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput { message: format!("non-ascii database name '{}' provided", db_name), })?, ); @@ -319,22 +325,30 @@ impl RestfulLanceDbClient { self.client.post(full_uri) } - pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result { + pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> { let (client, request) = req.build_split(); let mut request = request.unwrap(); // Set a request id. // TODO: allow the user to supply this, through middleware? 
- if request.headers().get(REQUEST_ID_HEADER).is_none() { - let request_id = uuid::Uuid::new_v4(); - let request_id = HeaderValue::from_str(&request_id.to_string()).unwrap(); - request.headers_mut().insert(REQUEST_ID_HEADER, request_id); - } + let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) { + request_id.to_str().unwrap().to_string() + } else { + let request_id = uuid::Uuid::new_v4().to_string(); + let header = HeaderValue::from_str(&request_id).unwrap(); + request.headers_mut().insert(REQUEST_ID_HEADER, header); + request_id + }; if with_retry { - self.send_with_retry_impl(client, request).await + self.send_with_retry_impl(client, request, request_id).await } else { - Ok(self.sender.send(&client, request).await?) + let response = self + .sender + .send(&client, request) + .await + .err_to_http(request_id.clone())?; + Ok((request_id, response)) } } @@ -342,98 +356,178 @@ impl RestfulLanceDbClient { &self, client: reqwest::Client, req: Request, - ) -> Result { - let mut request_failures = 0; - let mut connect_failures = 0; - let mut read_failures = 0; + request_id: String, + ) -> Result<(String, Response)> { + let mut retry_counter = RetryCounter::new(&self.retry_config, request_id); loop { // This only works if the request body is not a stream. If it is // a stream, we can't use the retry path. We would need to implement // an outer retry. 
- let request = req.try_clone().ok_or_else(|| Error::Http { + let request = req.try_clone().ok_or_else(|| Error::Runtime { message: "Attempted to retry a request that cannot be cloned".to_string(), })?; - let response = self.sender.send(&client, request).await; - let status_code = response.as_ref().map(|r| r.status()); - match status_code { - Ok(status) if status.is_success() => return Ok(response?), - Ok(status) if self.retry_config.statuses.contains(&status) => { - request_failures += 1; - if request_failures >= self.retry_config.retries { - // TODO: better error - return Err(Error::Runtime { - message: format!( - "Request failed after {} retries with status code {}", - request_failures, status - ), - }); - } + let response = self + .sender + .send(&client, request) + .await + .map(|r| (r.status(), r)); + match response { + Ok((status, response)) if status.is_success() => { + return Ok((retry_counter.request_id, response)) + } + Ok((status, response)) if self.retry_config.statuses.contains(&status) => { + let source = self + .check_response(&retry_counter.request_id, response) + .await + .unwrap_err(); + retry_counter.increment_request_failures(source)?; } Err(err) if err.is_connect() => { - connect_failures += 1; - if connect_failures >= self.retry_config.connect_retries { - return Err(Error::Runtime { - message: format!( - "Request failed after {} connect retries with error: {}", - connect_failures, err - ), - }); - } + retry_counter.increment_connect_failures(err)?; } Err(err) if err.is_timeout() || err.is_body() || err.is_decode() => { - read_failures += 1; - if read_failures >= self.retry_config.read_retries { - return Err(Error::Runtime { - message: format!( - "Request failed after {} read retries with error: {}", - read_failures, err - ), - }); - } + retry_counter.increment_read_failures(err)?; } - Ok(_) | Err(_) => return Ok(response?), + Err(err) => { + let status_code = err.status(); + return Err(Error::Http { + source: Box::new(err), + request_id: 
retry_counter.request_id, + status_code, + }); + } + Ok((_, response)) => return Ok((retry_counter.request_id, response)), } - let backoff = self.retry_config.backoff_factor * (2.0f32.powi(request_failures as i32)); - let jitter = rand::random::() * self.retry_config.backoff_jitter; - let sleep_time = Duration::from_secs_f32(backoff + jitter); - debug!( - "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}", - req.headers() - .get("x-request-id") - .and_then(|v| v.to_str().ok()), - connect_failures, - self.retry_config.connect_retries, - request_failures, - self.retry_config.retries, - read_failures, - self.retry_config.read_retries, - sleep_time - ); + let sleep_time = retry_counter.next_sleep_time(); tokio::time::sleep(sleep_time).await; } } - async fn rsp_to_str(response: Response) -> String { + pub async fn check_response(&self, request_id: &str, response: Response) -> Result { + // Try to get the response text, but if that fails, just return the status code let status = response.status(); - response.text().await.unwrap_or_else(|_| status.to_string()) + if status.is_success() { + Ok(response) + } else { + let response_text = response.text().await.ok(); + let message = if let Some(response_text) = response_text { + format!("{}: {}", status, response_text) + } else { + status.to_string() + }; + Err(Error::Http { + source: message.into(), + request_id: request_id.into(), + status_code: Some(status), + }) + } + } +} + +struct RetryCounter<'a> { + request_failures: u8, + connect_failures: u8, + read_failures: u8, + config: &'a ResolvedRetryConfig, + request_id: String, +} + +impl<'a> RetryCounter<'a> { + fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self { + Self { + request_failures: 0, + connect_failures: 0, + read_failures: 0, + config, + request_id, + } } - pub async fn check_response(&self, response: Response) -> Result { - let status_int: u16 = u16::from(response.status()); - if (400..500).contains(&status_int) { - 
Err(Error::InvalidInput { - message: Self::rsp_to_str(response).await, - }) - } else if status_int != 200 { - Err(Error::Runtime { - message: Self::rsp_to_str(response).await, + fn check_out_of_retries( + &self, + source: Box, + status_code: Option, + ) -> Result<()> { + if self.request_failures >= self.config.retries + || self.connect_failures >= self.config.connect_retries + || self.read_failures >= self.config.read_retries + { + Err(Error::Retry { + request_id: self.request_id.clone(), + request_failures: self.request_failures, + max_request_failures: self.config.retries, + connect_failures: self.connect_failures, + max_connect_failures: self.config.connect_retries, + read_failures: self.read_failures, + max_read_failures: self.config.read_retries, + source, + status_code, }) } else { - Ok(response) + Ok(()) } } + + fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> { + self.request_failures += 1; + let status_code = if let crate::Error::Http { status_code, .. 
} = &source { + *status_code + } else { + None + }; + self.check_out_of_retries(Box::new(source), status_code) + } + + fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> { + self.connect_failures += 1; + let status_code = source.status(); + self.check_out_of_retries(Box::new(source), status_code) + } + + fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> { + self.read_failures += 1; + let status_code = source.status(); + self.check_out_of_retries(Box::new(source), status_code) + } + + fn next_sleep_time(&self) -> Duration { + let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32)); + let jitter = rand::random::() * self.config.backoff_jitter; + let sleep_time = Duration::from_secs_f32(backoff + jitter); + debug!( + "Retrying request {:?} ({}/{} connect, {}/{} request, {}/{} read) in {:?}", + self.request_id, + self.connect_failures, + self.config.connect_retries, + self.request_failures, + self.config.retries, + self.read_failures, + self.config.read_retries, + sleep_time + ); + sleep_time + } +} + +pub trait RequestResultExt { + type Output; + fn err_to_http(self, request_id: String) -> Result; +} + +impl RequestResultExt for reqwest::Result { + type Output = T; + fn err_to_http(self, request_id: String) -> Result { + self.map_err(|err| { + let status_code = err.status(); + Error::Http { + source: Box::new(err), + request_id, + status_code, + } + }) + } } #[cfg(test)] diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 036e5e7c..8fe415be 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -29,7 +29,7 @@ use crate::embeddings::EmbeddingRegistry; use crate::error::Result; use crate::Table; -use super::client::{ClientConfig, HttpSend, RestfulLanceDbClient, Sender}; +use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; use super::util::batches_to_ipc_bytes; use 
super::ARROW_STREAM_CONTENT_TYPE; @@ -105,9 +105,13 @@ impl ConnectionInternal for RemoteDatabase { if let Some(start_after) = options.start_after { req = req.query(&[("page_token", start_after)]); } - let rsp = self.client.send(req, true).await?; - let rsp = self.client.check_response(rsp).await?; - let tables = rsp.json::().await?.tables; + let (request_id, rsp) = self.client.send(req, true).await?; + let rsp = self.client.check_response(&request_id, rsp).await?; + let tables = rsp + .json::() + .await + .err_to_http(request_id)? + .tables; for table in &tables { self.table_cache.insert(table.clone(), ()).await; } @@ -130,13 +134,11 @@ impl ConnectionInternal for RemoteDatabase { .client .post(&format!("/v1/table/{}/create/", options.name)) .body(data_buffer) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) - // This is currently expected by LanceDb cloud but will be removed soon. - .header("x-request-id", "na"); - let rsp = self.client.send(req, false).await?; + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); + let (request_id, rsp) = self.client.send(req, false).await?; if rsp.status() == StatusCode::BAD_REQUEST { - let body = rsp.text().await?; + let body = rsp.text().await.err_to_http(request_id.clone())?; if body.contains("already exists") { return Err(crate::Error::TableAlreadyExists { name: options.name }); } else { @@ -144,7 +146,7 @@ impl ConnectionInternal for RemoteDatabase { } } - self.client.check_response(rsp).await?; + self.client.check_response(&request_id, rsp).await?; self.table_cache.insert(options.name.clone(), ()).await; @@ -160,11 +162,11 @@ impl ConnectionInternal for RemoteDatabase { let req = self .client .get(&format!("/v1/table/{}/describe/", options.name)); - let resp = self.client.send(req, true).await?; + let (request_id, resp) = self.client.send(req, true).await?; if resp.status() == StatusCode::NOT_FOUND { return Err(crate::Error::TableNotFound { name: options.name }); } - self.client.check_response(resp).await?; + 
self.client.check_response(&request_id, resp).await?; } Ok(Table::new(Arc::new(RemoteTable::new( @@ -178,8 +180,8 @@ impl ConnectionInternal for RemoteDatabase { .client .post(&format!("/v1/table/{}/rename/", current_name)); let req = req.json(&serde_json::json!({ "new_table_name": new_name })); - let resp = self.client.send(req, false).await?; - self.client.check_response(resp).await?; + let (request_id, resp) = self.client.send(req, false).await?; + self.client.check_response(&request_id, resp).await?; self.table_cache.remove(current_name).await; self.table_cache.insert(new_name.into(), ()).await; Ok(()) @@ -187,8 +189,8 @@ impl ConnectionInternal for RemoteDatabase { async fn drop_table(&self, name: &str) -> Result<()> { let req = self.client.post(&format!("/v1/table/{}/drop/", name)); - let resp = self.client.send(req, true).await?; - self.client.check_response(resp).await?; + let (request_id, resp) = self.client.send(req, true).await?; + self.client.check_response(&request_id, resp).await?; self.table_cache.remove(name).await; Ok(()) } @@ -206,16 +208,57 @@ impl ConnectionInternal for RemoteDatabase { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::sync::{Arc, OnceLock}; use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use crate::{ remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE}, - Connection, + Connection, Error, }; + #[tokio::test] + async fn test_retries() { + // We'll record the request_id here, to check it matches the one in the error. + let seen_request_id = Arc::new(OnceLock::new()); + let seen_request_id_ref = seen_request_id.clone(); + let conn = Connection::new_with_handler(move |request| { + // Request id should be the same on each retry. 
+ let request_id = request.headers()["x-request-id"] + .to_str() + .unwrap() + .to_string(); + let seen_id = seen_request_id_ref.get_or_init(|| request_id.clone()); + assert_eq!(&request_id, seen_id); + + http::Response::builder() + .status(500) + .body("internal server error") + .unwrap() + }); + let result = conn.table_names().execute().await; + if let Err(Error::Retry { + request_id, + request_failures, + max_request_failures, + source, + .. + }) = result + { + let expected_id = seen_request_id.get().unwrap(); + assert_eq!(&request_id, expected_id); + assert_eq!(request_failures, max_request_failures); + assert!( + source.to_string().contains("internal server error"), + "source: {:?}", + source + ); + } else { + panic!("unexpected result: {:?}", result); + }; + } + #[tokio::test] async fn test_table_names() { let conn = Connection::new_with_handler(|request| { diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index d68907f3..81fb7a90 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -34,6 +34,7 @@ use crate::{ }, }; +use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE}; @@ -53,15 +54,25 @@ impl RemoteTable { let request = self .client .post(&format!("/v1/table/{}/describe/", self.name)); - let response = self.client.send(request, true).await?; + let (request_id, response) = self.client.send(request, true).await?; - let response = self.check_table_response(response).await?; + let response = self.check_table_response(&request_id, response).await?; - let body = response.text().await?; - - serde_json::from_str(&body).map_err(|e| Error::Http { - message: format!("Failed to parse table description: {}", e), - }) + match response.text().await { + Ok(body) => serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse table description: {}", e).into(), + request_id, + 
status_code: None, + }), + Err(err) => { + let status_code = err.status(); + Err(Error::Http { + source: Box::new(err), + request_id, + status_code, + }) + } + } } fn reader_as_body(data: Box) -> Result { @@ -87,18 +98,23 @@ impl RemoteTable { Ok(reqwest::Body::wrap_stream(body_stream)) } - async fn check_table_response(&self, response: reqwest::Response) -> Result { + async fn check_table_response( + &self, + request_id: &str, + response: reqwest::Response, + ) -> Result { if response.status() == StatusCode::NOT_FOUND { return Err(Error::TableNotFound { name: self.name.clone(), }); } - self.client.check_response(response).await + self.client.check_response(request_id, response).await } async fn read_arrow_stream( &self, + request_id: &str, body: reqwest::Response, ) -> Result { // Assert that the content type is correct @@ -106,24 +122,31 @@ impl RemoteTable { .headers() .get(CONTENT_TYPE) .ok_or_else(|| Error::Http { - message: "Missing content type".into(), + source: "Missing content type".into(), + request_id: request_id.to_string(), + status_code: None, })? .to_str() .map_err(|e| Error::Http { - message: format!("Failed to parse content type: {}", e), + source: format!("Failed to parse content type: {}", e).into(), + request_id: request_id.to_string(), + status_code: None, })?; if content_type != ARROW_STREAM_CONTENT_TYPE { return Err(Error::Http { - message: format!( + source: format!( "Expected content type {}, got {}", ARROW_STREAM_CONTENT_TYPE, content_type - ), + ) + .into(), + request_id: request_id.to_string(), + status_code: None, }); } // There isn't a way to actually stream this data yet. 
I have an upstream issue: // https://github.com/apache/arrow-rs/issues/6420 - let body = body.bytes().await?; + let body = body.bytes().await.err_to_http(request_id.into())?; let reader = StreamReader::try_new(body.reader(), None)?; let schema = reader.schema(); let stream = futures::stream::iter(reader).map_err(DataFusionError::from); @@ -259,14 +282,16 @@ impl TableInternal for RemoteTable { request = request.json(&serde_json::json!({})); } - let response = self.client.send(request, true).await?; + let (request_id, response) = self.client.send(request, true).await?; - let response = self.check_table_response(response).await?; + let response = self.check_table_response(&request_id, response).await?; - let body = response.text().await?; + let body = response.text().await.err_to_http(request_id.clone())?; serde_json::from_str(&body).map_err(|e| Error::Http { - message: format!("Failed to parse row count: {}", e), + source: format!("Failed to parse row count: {}", e).into(), + request_id, + status_code: None, }) } async fn add( @@ -288,9 +313,9 @@ impl TableInternal for RemoteTable { } } - let response = self.client.send(request, false).await?; + let (request_id, response) = self.client.send(request, false).await?; - self.check_table_response(response).await?; + self.check_table_response(&request_id, response).await?; Ok(()) } @@ -339,9 +364,9 @@ impl TableInternal for RemoteTable { let request = request.json(&body); - let response = self.client.send(request, true).await?; + let (request_id, response) = self.client.send(request, true).await?; - let stream = self.read_arrow_stream(response).await?; + let stream = self.read_arrow_stream(&request_id, response).await?; Ok(Arc::new(OneShotExec::new(stream))) } @@ -361,9 +386,9 @@ impl TableInternal for RemoteTable { let request = request.json(&body); - let response = self.client.send(request, true).await?; + let (request_id, response) = self.client.send(request, true).await?; - let stream = 
self.read_arrow_stream(response).await?; + let stream = self.read_arrow_stream(&request_id, response).await?; Ok(DatasetRecordBatchStream::new(stream)) } @@ -383,17 +408,20 @@ impl TableInternal for RemoteTable { "only_if": update.filter, })); - let response = self.client.send(request, false).await?; + let (request_id, response) = self.client.send(request, false).await?; - let response = self.check_table_response(response).await?; + let response = self.check_table_response(&request_id, response).await?; - let body = response.text().await?; + let body = response.text().await.err_to_http(request_id.clone())?; serde_json::from_str(&body).map_err(|e| Error::Http { - message: format!( + source: format!( "Failed to parse updated rows result from response {}: {}", body, e - ), + ) + .into(), + request_id, + status_code: None, }) } async fn delete(&self, predicate: &str) -> Result<()> { @@ -402,8 +430,8 @@ impl TableInternal for RemoteTable { .client .post(&format!("/v1/table/{}/delete/", self.name)) .json(&body); - let response = self.client.send(request, false).await?; - self.check_table_response(response).await?; + let (request_id, response) = self.client.send(request, false).await?; + self.check_table_response(&request_id, response).await?; Ok(()) } @@ -474,9 +502,9 @@ impl TableInternal for RemoteTable { let request = request.json(&body); - let response = self.client.send(request, false).await?; + let (request_id, response) = self.client.send(request, false).await?; - self.check_table_response(response).await?; + self.check_table_response(&request_id, response).await?; Ok(()) } @@ -495,9 +523,9 @@ impl TableInternal for RemoteTable { .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .body(body); - let response = self.client.send(request, false).await?; + let (request_id, response) = self.client.send(request, false).await?; - self.check_table_response(response).await?; + self.check_table_response(&request_id, response).await?; Ok(()) } @@ -531,8 +559,8 @@ impl 
TableInternal for RemoteTable { let request = self .client .post(&format!("/v1/table/{}/index/list/", self.name)); - let response = self.client.send(request, true).await?; - let response = self.check_table_response(response).await?; + let (request_id, response) = self.client.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; #[derive(Deserialize)] struct ListIndicesResponse { @@ -545,12 +573,15 @@ impl TableInternal for RemoteTable { columns: Vec, } - let body = response.text().await?; + let body = response.text().await.err_to_http(request_id.clone())?; let body: ListIndicesResponse = serde_json::from_str(&body).map_err(|err| Error::Http { - message: format!( + source: format!( "Failed to parse list_indices response: {}, body: {}", err, body - ), + ) + .into(), + request_id, + status_code: None, })?; // Make request to get stats for each index, so we get the index type. @@ -581,18 +612,20 @@ impl TableInternal for RemoteTable { "/v1/table/{}/index/{}/stats/", self.name, index_name )); - let response = self.client.send(request, true).await?; + let (request_id, response) = self.client.send(request, true).await?; if response.status() == StatusCode::NOT_FOUND { return Ok(None); } - let response = self.check_table_response(response).await?; + let response = self.check_table_response(&request_id, response).await?; - let body = response.text().await?; + let body = response.text().await.err_to_http(request_id.clone())?; let stats = serde_json::from_str(&body).map_err(|e| Error::Http { - message: format!("Failed to parse index statistics: {}", e), + source: format!("Failed to parse index statistics: {}", e).into(), + request_id, + status_code: None, })?; Ok(Some(stats))