diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb index 7174de89..206c277a 100644 --- a/notebooks/youtube_transcript_search.ipynb +++ b/notebooks/youtube_transcript_search.ipynb @@ -499,7 +499,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "421a678d", + "id": "c71f5b31", "metadata": {}, "outputs": [], "source": [ @@ -510,7 +510,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "80b160f0", + "id": "603ba92c", "metadata": {}, "outputs": [], "source": [ @@ -521,7 +521,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "7c3ed619", + "id": "80db5c15", "metadata": {}, "outputs": [], "source": [ diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py index 0f7c1eb6..b0457ae3 100644 --- a/python/lancedb/embeddings.py +++ b/python/lancedb/embeddings.py @@ -52,23 +52,36 @@ class EmbeddingFunction: def __call__(self, text): # Get the embedding with retry - @retry(**self.retry_kwargs) - def embed_func(c): - return self.func(c.tolist()) + if len(self.retry_kwargs) > 0: - max_calls = self.rate_limiter_kwargs["max_calls"] - limiter = ratelimiter.RateLimiter( - max_calls, period=self.rate_limiter_kwargs["period"] - ) - rate_limited = limiter(embed_func) + @retry(**self.retry_kwargs) + def embed_func(c): + return self.func(c.tolist()) + + else: + + def embed_func(c): + return self.func(c.tolist()) + + if len(self.rate_limiter_kwargs) > 0: + max_calls = self.rate_limiter_kwargs["max_calls"] + limiter = ratelimiter.RateLimiter( + max_calls, period=self.rate_limiter_kwargs["period"] + ) + embed_func = limiter(embed_func) batches = self.to_batches(text) - embeds = [emb for c in batches for emb in rate_limited(c)] + embeds = [emb for c in batches for emb in embed_func(c)] return embeds def __repr__(self): return f"EmbeddingFunction(func={self.func})" def rate_limit(self, max_calls=0.9, period=1.0): + import sys + + v = int(sys.version_info.minor) + if v >= 11: + raise ValueError("rate limit only support up to 3.10") self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period) return self