feat: better errors for remote SDK (#1722)

* Adds nicer errors to remote SDK, that expose useful properties like
`request_id` and `status_code`.
* Makes sure the Python tracebacks print nicely by mapping the `source`
field from a Rust error to the `__cause__` field.
This commit is contained in:
Will Jones
2024-10-08 21:21:13 -07:00
committed by GitHub
parent 607476788e
commit 8509f73221
9 changed files with 622 additions and 193 deletions

View File

@@ -103,19 +103,29 @@ class RestfulLanceDBClient:
@staticmethod
def _check_status(resp: requests.Response):
# Leaving request id empty for now, as we'll be replacing this impl
# with the Rust one shortly.
if resp.status_code == 404:
raise LanceDBClientError(f"Not found: {resp.text}")
raise LanceDBClientError(
f"Not found: {resp.text}", request_id="", status_code=404
)
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status_code}, error: {resp.text}"
f"Bad Request: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status_code}, error: {resp.text}"
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status_code}, error: {resp.text}"
f"Unknown Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
@_check_not_closed

View File

@@ -12,5 +12,102 @@
# limitations under the License.
from typing import Optional
class LanceDBClientError(RuntimeError):
"""An error that occurred in the LanceDB client.
Attributes
----------
message: str
The error message.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
status_code: int
The HTTP status code of the response. May be None if the request
failed before the response was received.
"""
def __init__(
self, message: str, request_id: str, status_code: Optional[int] = None
):
super().__init__(message)
self.request_id = request_id
self.status_code = status_code
class HttpError(LanceDBClientError):
"""An error that occurred during an HTTP request.
Attributes
----------
message: str
The error message.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
status_code: int
The HTTP status code of the response. May be None if the request
failed before the response was received.
"""
pass
class RetryError(LanceDBClientError):
"""An error that occurs when the client has exceeded the maximum number of retries.
The retry strategy can be adjusted by setting the
[retry_config](lancedb.remote.ClientConfig.retry_config) in the client
configuration. This is passed in the `client_config` argument of
[connect](lancedb.connect) and [connect_async](lancedb.connect_async).
The __cause__ attribute of this exception will be the last exception that
caused the retry to fail. It will be an
[HttpError][lancedb.remote.errors.HttpError] instance.
Attributes
----------
message: str
The retry error message, which will describe which retry limit was hit.
request_id: str
The id of the request that failed. This can be provided in error reports
to help diagnose the issue.
request_failures: int
The number of request failures.
connect_failures: int
The number of connect failures.
read_failures: int
The number of read failures.
max_request_failures: int
The maximum number of request failures.
max_connect_failures: int
The maximum number of connect failures.
max_read_failures: int
The maximum number of read failures.
status_code: int
The HTTP status code of the last response. May be None if the request
failed before the response was received.
"""
def __init__(
self,
message: str,
request_id: str,
request_failures: int,
connect_failures: int,
read_failures: int,
max_request_failures: int,
max_connect_failures: int,
max_read_failures: int,
status_code: Optional[int],
):
super().__init__(message, request_id, status_code)
self.request_failures = request_failures
self.connect_failures = connect_failures
self.read_failures = read_failures
self.max_request_failures = max_request_failures
self.max_connect_failures = max_connect_failures
self.max_read_failures = max_read_failures

View File

@@ -354,7 +354,7 @@ async def test_create_mode_async(tmp_path):
)
await db.create_table("test", data=data)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
new_data = pd.DataFrame(
@@ -382,7 +382,7 @@ async def test_create_exist_ok_async(tmp_path):
)
tbl = await db.create_table("test", data=data)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
# open the table but don't add more rows

View File

@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib
import http.server
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
@@ -98,6 +100,33 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler
@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = await lancedb.connect_async(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@pytest.mark.asyncio
async def test_async_remote_db():
def handler(request):
@@ -114,28 +143,50 @@ async def test_async_remote_db():
request.end_headers()
request.wfile.write(b'{"tables": []}')
def run_server():
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
# we will only make one request
server.handle_request()
async with mock_lancedb_connection(handler) as db:
table_names = await db.table_names()
assert table_names == []
handle = threading.Thread(target=run_server)
handle.start()
db = await lancedb.connect_async(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
table_names = await db.table_names()
assert table_names == []
@pytest.mark.asyncio
async def test_http_error():
request_id_holder = {"request_id": None}
handle.join()
def handler(request):
request_id_holder["request_id"] = request.headers["x-request-id"]
request.send_response(507)
request.end_headers()
request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507
@pytest.mark.asyncio
async def test_retry_error():
request_id_holder = {"request_id": None}
def handler(request):
request_id_holder["request_id"] = request.headers["x-request-id"]
request.send_response(429)
request.end_headers()
request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__
assert isinstance(cause, HttpError)
assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429

View File

@@ -14,7 +14,9 @@
use pyo3::{
exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError},
PyResult,
intern,
types::{PyAnyMethods, PyNone},
PyErr, PyResult, Python,
};
use lancedb::error::Error as LanceError;
@@ -38,12 +40,79 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
LanceError::InvalidInput { .. }
| LanceError::InvalidTableName { .. }
| LanceError::TableNotFound { .. }
| LanceError::Schema { .. } => self.value_error(),
| LanceError::Schema { .. }
| LanceError::TableAlreadyExists { .. } => self.value_error(),
LanceError::CreateDir { .. } => self.os_error(),
LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
LanceError::NotSupported { .. } => {
Err(PyNotImplementedError::new_err(err.to_string()))
}
LanceError::Http {
request_id,
source,
status_code,
} => Python::with_gil(|py| {
let message = err.to_string();
let http_err_cls = py
.import_bound(intern!(py, "lancedb.remote.errors"))?
.getattr(intern!(py, "HttpError"))?;
let err = http_err_cls.call1((
message,
request_id,
status_code.map(|s| s.as_u16()),
))?;
if let Some(cause) = source.source() {
// The HTTP error already includes the first cause. But
// we can add the rest of the chain if there is any more.
let cause_err = http_from_rust_error(
py,
cause,
request_id,
status_code.map(|s| s.as_u16()),
)?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
}
Err(PyErr::from_value_bound(err))
}),
LanceError::Retry {
request_id,
request_failures,
max_request_failures,
connect_failures,
max_connect_failures,
read_failures,
max_read_failures,
source,
status_code,
} => Python::with_gil(|py| {
let cause_err = http_from_rust_error(
py,
source.as_ref(),
request_id,
status_code.map(|s| s.as_u16()),
)?;
let message = err.to_string();
let retry_error_cls = py
.import_bound(intern!(py, "lancedb.remote.errors"))?
.getattr("RetryError")?;
let err = retry_error_cls.call1((
message,
request_id,
*request_failures,
*connect_failures,
*read_failures,
*max_request_failures,
*max_connect_failures,
*max_read_failures,
status_code.map(|s| s.as_u16()),
))?;
err.setattr(intern!(py, "__cause__"), cause_err)?;
Err(PyErr::from_value_bound(err))
}),
_ => self.runtime_error(),
},
}
@@ -61,3 +130,24 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
self.map_err(|err| PyValueError::new_err(err.to_string()))
}
}
fn http_from_rust_error(
py: Python<'_>,
err: &dyn std::error::Error,
request_id: &str,
status_code: Option<u16>,
) -> PyResult<PyErr> {
let message = err.to_string();
let http_err_cls = py.import("lancedb.remote.errors")?.getattr("HttpError")?;
let py_err = http_err_cls.call1((message, request_id, status_code))?;
// Reset the traceback since it doesn't provide additional information.
let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get_bound(py),))?;
if let Some(cause) = err.source() {
let cause_err = http_from_rust_error(py, cause, request_id, status_code)?;
py_err.setattr(intern!(py, "__cause__"), cause_err)?;
}
Ok(PyErr::from_value(py_err))
}