mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 06:19:57 +00:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ba7fa15a4 | ||
|
|
370867836c | ||
|
|
682f09480c | ||
|
|
cd8807bc97 | ||
|
|
41c44ae92e | ||
|
|
6865d66d37 | ||
|
|
aeecd809cc | ||
|
|
3360678d60 | ||
|
|
177eddfc20 | ||
|
|
d735a69b6e | ||
|
|
a2bd2854e1 | ||
|
|
c32b6880e7 | ||
|
|
c6fe5e38f1 | ||
|
|
1c8b52f07b | ||
|
|
f544c5dd31 | ||
|
|
eba533da4f | ||
|
|
404211d4fb | ||
|
|
5d7832c8a5 | ||
|
|
4eba83fdc9 | ||
|
|
6906a5f912 | ||
|
|
69fd80e9f2 |
@@ -41,9 +41,14 @@ pip install lancedb
|
||||
```python
|
||||
import lancedb
|
||||
|
||||
uri = "/tmp/lancedb"
|
||||
db = lancedb.connect(uri)
|
||||
table = db.create_table("my_table",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
|
||||
result = table.search([100, 100]).limit(2).to_df()
|
||||
```
|
||||
|
||||
## Blogs, Tutorials & Videos
|
||||
* 📈 <a href="https://blog.eto.ai/benchmarking-random-access-in-lance-ed690757a826">2000x better performance with Lance over Parquet</a>
|
||||
* 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a>
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -12,7 +12,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import ratelimiter
|
||||
from retry import retry
|
||||
from typing import Callable, Union
|
||||
|
||||
@@ -32,11 +31,12 @@ def with_embeddings(
|
||||
):
|
||||
func = EmbeddingFunction(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit().batch_size(batch_size)
|
||||
func = func.retry().rate_limit()
|
||||
func = func.batch_size(batch_size)
|
||||
if show_progress:
|
||||
func = func.show_progress()
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data)
|
||||
data = pa.Table.from_pandas(data, preserve_index=False)
|
||||
embeddings = func(data[column].to_numpy())
|
||||
table = vec_to_table(np.array(embeddings))
|
||||
return data.append_column("vector", table["vector"])
|
||||
@@ -52,23 +52,38 @@ class EmbeddingFunction:
|
||||
|
||||
def __call__(self, text):
|
||||
# Get the embedding with retry
|
||||
@retry(**self.retry_kwargs)
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
if len(self.retry_kwargs) > 0:
|
||||
|
||||
max_calls = self.rate_limiter_kwargs["max_calls"]
|
||||
limiter = ratelimiter.RateLimiter(
|
||||
max_calls, period=self.rate_limiter_kwargs["period"]
|
||||
)
|
||||
rate_limited = limiter(embed_func)
|
||||
@retry(**self.retry_kwargs)
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
|
||||
else:
|
||||
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
|
||||
if len(self.rate_limiter_kwargs) > 0:
|
||||
import ratelimiter
|
||||
|
||||
max_calls = self.rate_limiter_kwargs["max_calls"]
|
||||
limiter = ratelimiter.RateLimiter(
|
||||
max_calls, period=self.rate_limiter_kwargs["period"]
|
||||
)
|
||||
embed_func = limiter(embed_func)
|
||||
batches = self.to_batches(text)
|
||||
embeds = [emb for c in batches for emb in rate_limited(c)]
|
||||
embeds = [emb for c in batches for emb in embed_func(c)]
|
||||
return embeds
|
||||
|
||||
def __repr__(self):
|
||||
return f"EmbeddingFunction(func={self.func})"
|
||||
|
||||
def rate_limit(self, max_calls=0.9, period=1.0):
|
||||
import sys
|
||||
|
||||
v = int(sys.version_info.minor)
|
||||
if v >= 11:
|
||||
raise ValueError("rate limit only support up to 3.10")
|
||||
self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
|
||||
return self
|
||||
|
||||
@@ -102,4 +117,4 @@ class EmbeddingFunction:
|
||||
|
||||
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
|
||||
else:
|
||||
return _chunker(arr)
|
||||
yield from _chunker(arr)
|
||||
|
||||
@@ -55,6 +55,27 @@ class LanceTable:
|
||||
"""Return the schema of the table."""
|
||||
return self._dataset.schema
|
||||
|
||||
def __len__(self):
|
||||
return self._dataset.count_rows()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LanceTable({self.name})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
def head(self, n=5) -> pa.Table:
|
||||
"""Return the first n rows of the table."""
|
||||
return self._dataset.head(n)
|
||||
|
||||
def to_pandas(self) -> pd.DataFrame:
|
||||
"""Return the table as a pandas DataFrame."""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table."""
|
||||
return self._dataset.to_table()
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return os.path.join(self._conn.uri, f"{self.name}.lance")
|
||||
@@ -150,6 +171,7 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
|
||||
return data
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.0.2"
|
||||
version = "0.0.4"
|
||||
dependencies = ["pylance", "ratelimiter", "retry", "tqdm"]
|
||||
description = "lancedb"
|
||||
authors = [
|
||||
|
||||
42
python/tests/test_embeddings.py
Normal file
42
python/tests/test_embeddings.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.embeddings import with_embeddings
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
return [np.random.randn(128).tolist() for _ in range(len(input_data))]
|
||||
|
||||
|
||||
def test_with_embeddings():
|
||||
for wrap_api in [True, False]:
|
||||
if wrap_api and sys.version_info.minor >= 11:
|
||||
# ratelimiter package doesn't work on 3.11
|
||||
continue
|
||||
data = pa.Table.from_arrays(
|
||||
[
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
],
|
||||
names=["text", "price"],
|
||||
)
|
||||
data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
|
||||
assert data.num_columns == 3
|
||||
assert data.num_rows == 2
|
||||
assert data.column_names == ["text", "price", "vector"]
|
||||
assert data.column("text").to_pylist() == ["foo", "bar"]
|
||||
assert data.column("price").to_pylist() == [10.0, 20.0]
|
||||
@@ -46,17 +46,17 @@ def test_basic(db):
|
||||
assert table.to_lance().to_table() == ds.to_table()
|
||||
|
||||
|
||||
def test_add(db):
|
||||
def test_create_table(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32())),
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float32()),
|
||||
]
|
||||
)
|
||||
expected = pa.Table.from_arrays(
|
||||
[
|
||||
pa.array([[3.1, 4.1], [5.9, 26.5]]),
|
||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
],
|
||||
@@ -79,3 +79,34 @@ def test_add(db):
|
||||
.to_table()
|
||||
)
|
||||
assert expected == tbl
|
||||
|
||||
|
||||
def test_add(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
|
||||
# table = LanceTable(db, "test")
|
||||
assert len(table) == 2
|
||||
|
||||
count = table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert count == 3
|
||||
|
||||
expected = pa.Table.from_arrays(
|
||||
[
|
||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
],
|
||||
schema=pa.schema([
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]),
|
||||
)
|
||||
assert expected == table.to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user