mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 21:02:58 +00:00
fix 3.11
This commit is contained in:
@@ -499,7 +499,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "421a678d",
|
||||
"id": "c71f5b31",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -510,7 +510,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "80b160f0",
|
||||
"id": "603ba92c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -521,7 +521,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "7c3ed619",
|
||||
"id": "80db5c15",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
||||
@@ -52,23 +52,36 @@ class EmbeddingFunction:
|
||||
|
||||
def __call__(self, text):
|
||||
# Get the embedding with retry
|
||||
@retry(**self.retry_kwargs)
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
if len(self.retry_kwargs) > 0:
|
||||
|
||||
max_calls = self.rate_limiter_kwargs["max_calls"]
|
||||
limiter = ratelimiter.RateLimiter(
|
||||
max_calls, period=self.rate_limiter_kwargs["period"]
|
||||
)
|
||||
rate_limited = limiter(embed_func)
|
||||
@retry(**self.retry_kwargs)
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
|
||||
else:
|
||||
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
|
||||
if len(self.rate_limiter_kwargs) > 0:
|
||||
max_calls = self.rate_limiter_kwargs["max_calls"]
|
||||
limiter = ratelimiter.RateLimiter(
|
||||
max_calls, period=self.rate_limiter_kwargs["period"]
|
||||
)
|
||||
embed_func = limiter(embed_func)
|
||||
batches = self.to_batches(text)
|
||||
embeds = [emb for c in batches for emb in rate_limited(c)]
|
||||
embeds = [emb for c in batches for emb in embed_func(c)]
|
||||
return embeds
|
||||
|
||||
def __repr__(self):
|
||||
return f"EmbeddingFunction(func={self.func})"
|
||||
|
||||
def rate_limit(self, max_calls=0.9, period=1.0):
|
||||
import sys
|
||||
|
||||
v = int(sys.version_info.minor)
|
||||
if v >= 11:
|
||||
raise ValueError("rate limit only support up to 3.10")
|
||||
self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user