This commit is contained in:
Chang She
2023-03-24 19:00:22 -07:00
parent 5d7832c8a5
commit 404211d4fb
2 changed files with 25 additions and 12 deletions

View File

@@ -499,7 +499,7 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "421a678d",
"id": "c71f5b31",
"metadata": {},
"outputs": [],
"source": [
@@ -510,7 +510,7 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "80b160f0",
"id": "603ba92c",
"metadata": {},
"outputs": [],
"source": [
@@ -521,7 +521,7 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "7c3ed619",
"id": "80db5c15",
"metadata": {},
"outputs": [],
"source": [

View File

@@ -52,23 +52,36 @@ class EmbeddingFunction:
def __call__(self, text):
    """Embed ``text`` batch by batch and return all embeddings as one flat list.

    ``self.to_batches(text)`` splits the input. Each batch is converted with
    ``.tolist()`` and passed to ``self.func``. That call is wrapped with
    ``retry`` when ``self.retry_kwargs`` is non-empty and with a
    ``ratelimiter.RateLimiter`` when ``self.rate_limiter_kwargs`` is
    non-empty; when both are empty ``self.func`` is invoked directly.

    Args:
        text: the input handed unchanged to ``self.to_batches``.

    Returns:
        list: every item produced by ``self.func`` for each batch, in batch
        order.
    """

    def embed_func(c):
        return self.func(c.tolist())

    # Retry is optional: only wrap when retry settings were configured.
    if self.retry_kwargs:
        embed_func = retry(**self.retry_kwargs)(embed_func)

    # Rate limiting is configured independently of retry, so it is gated on
    # its own settings and wraps the (possibly retried) call.
    if self.rate_limiter_kwargs:
        limiter = ratelimiter.RateLimiter(
            self.rate_limiter_kwargs["max_calls"],
            period=self.rate_limiter_kwargs["period"],
        )
        embed_func = limiter(embed_func)

    batches = self.to_batches(text)
    # Each batch is embedded exactly once; results are flattened in order.
    return [emb for c in batches for emb in embed_func(c)]
def __repr__(self):
    """Return a debug string that shows the wrapped ``func``."""
    description = f"EmbeddingFunction(func={self.func})"
    return description
def rate_limit(self, max_calls=0.9, period=1.0):
    """Configure rate limiting for later embedding calls.

    Stores the settings in ``self.rate_limiter_kwargs`` as a dict with the
    keys ``"max_calls"`` and ``"period"``.

    Args:
        max_calls: value stored under ``"max_calls"``.
        period: value stored under ``"period"``.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        ValueError: when running on Python 3.11 or newer; nothing is stored
            in that case.
    """
    import sys

    # Compare the full (major, minor) pair rather than the minor number alone
    # so the gate stays correct regardless of the major version.
    # NOTE(review): presumably the rate-limiter dependency does not work on
    # 3.11+ -- confirm before relaxing this check.
    if sys.version_info >= (3, 11):
        raise ValueError("rate limit only support up to 3.10")
    self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
    return self