diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py index b0457ae3..2634dd18 100644 --- a/python/lancedb/embeddings.py +++ b/python/lancedb/embeddings.py @@ -12,7 +12,6 @@ # limitations under the License. import math -import ratelimiter from retry import retry from typing import Callable, Union @@ -32,7 +31,8 @@ def with_embeddings( ): func = EmbeddingFunction(func) if wrap_api: - func = func.retry().rate_limit().batch_size(batch_size) + func = func.retry().rate_limit() + func = func.batch_size(batch_size) if show_progress: func = func.show_progress() if isinstance(data, pd.DataFrame): @@ -64,6 +64,8 @@ class EmbeddingFunction: return self.func(c.tolist()) if len(self.rate_limiter_kwargs) > 0: + import ratelimiter + max_calls = self.rate_limiter_kwargs["max_calls"] limiter = ratelimiter.RateLimiter( max_calls, period=self.rate_limiter_kwargs["period"] diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index 2740ecb2..39938356 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -10,6 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import sys + import numpy as np import pyarrow as pa @@ -21,16 +23,20 @@ def mock_embed_func(input_data): def test_with_embeddings(): - data = pa.Table.from_arrays( - [ - pa.array(["foo", "bar"]), - pa.array([10.0, 20.0]), - ], - names=["text", "price"], - ) - data = with_embeddings(mock_embed_func, data) - assert data.num_columns == 3 - assert data.num_rows == 2 - assert data.column_names == ["text", "price", "vector"] - assert data.column("text").to_pylist() == ["foo", "bar"] - assert data.column("price").to_pylist() == [10.0, 20.0] + for wrap_api in [True, False]: + if wrap_api and sys.version_info >= (3, 11): + # ratelimiter package doesn't work on Python 3.11+, so skip the wrap_api=True case there + continue + data = pa.Table.from_arrays( + [ + pa.array(["foo", "bar"]), + pa.array([10.0, 20.0]), + ], + names=["text", "price"], + ) + data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api) + assert data.num_columns == 3 + assert data.num_rows == 2 + assert data.column_names == ["text", "price", "vector"] + assert data.column("text").to_pylist() == ["foo", "bar"] + assert data.column("price").to_pylist() == [10.0, 20.0]