mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
Exponential standoff retry support for handling rate limited embedding functions (#614)
Users ingesting data using rate limited apis don't need to manually make the process sleep for counter rate limits resolves #579
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -38,3 +40,26 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
|
||||
class RateLimitedAPI:
|
||||
rate_limit = 0.1 # 1 request per 0.1 second
|
||||
last_request_time = 0
|
||||
|
||||
@staticmethod
|
||||
def make_request():
|
||||
current_time = time.time()
|
||||
|
||||
if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
|
||||
raise Exception("Rate limit exceeded. Please try again later.")
|
||||
|
||||
# Simulate a successful request
|
||||
RateLimitedAPI.last_request_time = current_time
|
||||
return "Request successful"
|
||||
|
||||
|
||||
@registry.register("test-rate-limited")
|
||||
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
|
||||
def generate_embeddings(self, texts):
|
||||
RateLimitedAPI.make_request()
|
||||
return [self._compute_one_embedding(row) for row in texts]
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from .utils import TEXT
|
||||
from .utils import TEXT, retry_with_exponential_backoff
|
||||
|
||||
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
@@ -21,6 +21,9 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
max_retries: int = (
|
||||
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
|
||||
)
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
@@ -44,6 +47,25 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query with retries
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_query_embeddings, max_retries=self.max_retries
|
||||
)(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database with retries
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_source_embeddings, max_retries=self.max_retries
|
||||
)(*args, **kwargs)
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
from typing import Callable, List, Union
|
||||
|
||||
@@ -162,6 +164,55 @@ class FunctionWrapper:
|
||||
yield from _chunker(arr)
|
||||
|
||||
|
||||
def retry_with_exponential_backoff(
|
||||
func,
|
||||
initial_delay: float = 1,
|
||||
exponential_base: float = 2,
|
||||
jitter: bool = True,
|
||||
max_retries: int = 7,
|
||||
# errors: tuple = (),
|
||||
):
|
||||
"""Retry a function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func (function): The function to be retried.
|
||||
initial_delay (float): Initial delay in seconds (default is 1).
|
||||
exponential_base (float): The base for exponential backoff (default is 2).
|
||||
jitter (bool): Whether to add jitter to the delay (default is True).
|
||||
max_retries (int): Maximum number of retries (default is 10).
|
||||
errors (tuple): Tuple of specific exceptions to retry on (default is (openai.error.RateLimitError,)).
|
||||
|
||||
Returns:
|
||||
function: The decorated function.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
num_retries = 0
|
||||
delay = initial_delay
|
||||
|
||||
# Loop until a successful response or max_retries is hit or an exception is raised
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
|
||||
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
|
||||
# should check the error message to be sure
|
||||
except Exception as e:
|
||||
num_retries += 1
|
||||
|
||||
if num_retries > max_retries:
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
|
||||
time.sleep(delay)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -140,7 +140,7 @@ class LanceQueryBuilder(ABC):
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
@@ -151,7 +151,7 @@ class LanceQueryBuilder(ABC):
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
|
||||
@@ -86,7 +86,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
col_data = func.compute_source_embeddings_with_retry(
|
||||
data[conf.source_column]
|
||||
)
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
|
||||
@@ -15,13 +15,16 @@ import sys
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
import lancedb
|
||||
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -83,3 +86,29 @@ def test_embedding_function(tmp_path):
|
||||
expected = func.compute_query_embeddings("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
class Schema(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
return Schema
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("test-rate-limited").create(max_retries=0)
|
||||
schema = _get_schema_from_model(model)
|
||||
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||
table.add([{"text": "hello world"}])
|
||||
with pytest.raises(Exception):
|
||||
table.add([{"text": "hello world"}])
|
||||
assert len(table) == 1
|
||||
|
||||
model = registry.get("test-rate-limited").create()
|
||||
schema = _get_schema_from_model(model)
|
||||
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||
table.add([{"text": "hello world"}])
|
||||
table.add([{"text": "hello world"}])
|
||||
assert len(table) == 2
|
||||
|
||||
Reference in New Issue
Block a user