diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 35c7d4a6..11e18ab3 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,6 +13,8 @@ import functools +import logging +import os from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urljoin @@ -20,6 +22,8 @@ import attrs import pyarrow as pa import requests from pydantic import BaseModel +from requests.adapters import HTTPAdapter +from urllib3 import Retry from lancedb.common import Credential from lancedb.remote import VectorQuery, VectorQueryResult @@ -57,6 +61,10 @@ class RestfulLanceDBClient: @functools.cached_property def session(self) -> requests.Session: sess = requests.Session() + + retry_adapter_instance = retry_adapter(retry_adapter_options()) + sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance) + adapter_class = LanceDBClientHTTPAdapterFactory() sess.mount("https://", adapter_class()) return sess @@ -170,3 +178,72 @@ class RestfulLanceDBClient: """Query a table.""" tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) return VectorQueryResult(tbl) + + def mount_retry_adapter_for_table(self, table_name: str) -> None: + """ + Adds an http adapter to session that will retry retryable requests to the table. + """ + retry_options = retry_adapter_options(methods=["GET", "POST"]) + retry_adapter_instance = retry_adapter(retry_options) + session = self.session + + session.mount( + urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance + ) + session.mount( + urljoin(self.url, f"/v1/table/{table_name}/describe/"), + retry_adapter_instance, + ) + session.mount( + urljoin(self.url, f"/v1/table/{table_name}/index/list/"), + retry_adapter_instance, + ) + + +def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]: + return { + "retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")), + "connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")), + "read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")), + "backoff_factor": float( + os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25") + ), + "backoff_jitter": float( + os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25") + ), + "statuses": [ + int(i.strip()) + for i in os.environ.get( + "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503" + ).split(",") + ], + "methods": methods, + } + + +def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter: + total_retries = options["retries"] + connect_retries = options["connect_retries"] + read_retries = options["read_retries"] + backoff_factor = options["backoff_factor"] + backoff_jitter = options["backoff_jitter"] + statuses = options["statuses"] + methods = frozenset(options["methods"]) + logging.debug( + f"Setting up retry adapter with {total_retries} retries," # noqa G003 + + f"connect retries {connect_retries}, read retries {read_retries}," + + f"backoff factor {backoff_factor}, statuses {statuses}, " + + f"methods {methods}" + ) + + return HTTPAdapter( + max_retries=Retry( + total=total_retries, + connect=connect_retries, + read=read_retries, + backoff_factor=backoff_factor, + backoff_jitter=backoff_jitter, + status_forcelist=statuses, + allowed_methods=methods, + ) + ) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 3a88152f..2b84f6c2 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -95,6 +95,8 @@ class RemoteDBConnection(DBConnection): """ from .table import RemoteTable + self._client.mount_retry_adapter_for_table(name) + # check if table exists try: self._client.post(f"/v1/table/{name}/describe/") diff --git a/python/tests/test_remote_db.py b/python/tests/test_remote_db.py index d4928c6a..bca4451f 100644 --- a/python/tests/test_remote_db.py +++ b/python/tests/test_remote_db.py @@ -29,6 +29,9 @@ class FakeLanceDBClient: def post(self, path: str): pass + def mount_retry_adapter_for_table(self, table_name: str): + pass + def test_remote_db(): conn = lancedb.connect("db://client-will-be-injected", api_key="fake")