This commit is contained in:
Chang She
2023-03-24 19:45:46 -07:00
parent 404211d4fb
commit eba533da4f
2 changed files with 23 additions and 15 deletions

View File

@@ -12,7 +12,6 @@
# limitations under the License.
import math
import ratelimiter
from retry import retry
from typing import Callable, Union
@@ -32,7 +31,8 @@ def with_embeddings(
):
func = EmbeddingFunction(func)
if wrap_api:
func = func.retry().rate_limit().batch_size(batch_size)
func = func.retry().rate_limit()
func = func.batch_size(batch_size)
if show_progress:
func = func.show_progress()
if isinstance(data, pd.DataFrame):
@@ -64,6 +64,8 @@ class EmbeddingFunction:
return self.func(c.tolist())
if len(self.rate_limiter_kwargs) > 0:
import ratelimiter
max_calls = self.rate_limiter_kwargs["max_calls"]
limiter = ratelimiter.RateLimiter(
max_calls, period=self.rate_limiter_kwargs["period"]

View File

@@ -10,6 +10,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pyarrow as pa
@@ -21,16 +23,20 @@ def mock_embed_func(input_data):
def test_with_embeddings():
    """Check that ``with_embeddings`` adds a ``vector`` column to the table.

    The check runs once with ``wrap_api=True`` and once with
    ``wrap_api=False``. For each mode a fresh two-row table with ``text``
    and ``price`` columns is built, passed through ``with_embeddings``
    together with ``mock_embed_func``, and the result is expected to keep
    the original columns unchanged and gain a third column named
    ``vector``.
    """
    for wrap_api in (True, False):
        # ratelimiter package doesn't work on 3.11, so skip the wrapped
        # mode there. Compare the whole version tuple rather than just
        # ``.minor`` so the check stays correct across major versions.
        if wrap_api and sys.version_info >= (3, 11):
            continue

        # Rebuild the input on every iteration: ``data`` is rebound to the
        # output of ``with_embeddings`` below.
        data = pa.Table.from_arrays(
            [
                pa.array(["foo", "bar"]),
                pa.array([10.0, 20.0]),
            ],
            names=["text", "price"],
        )
        data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)

        # Shape: the two original columns plus the new embedding column.
        assert data.num_columns == 3
        assert data.num_rows == 2
        assert data.column_names == ["text", "price", "vector"]

        # The pre-existing columns must come through untouched.
        assert data.column("text").to_pylist() == ["foo", "bar"]
        assert data.column("price").to_pylist() == [10.0, 20.0]