update for release

This commit is contained in:
Chang She
2023-03-24 18:16:29 -07:00
parent 4eba83fdc9
commit 5d7832c8a5
4 changed files with 320 additions and 60 deletions

View File

@@ -36,7 +36,7 @@ def with_embeddings(
if show_progress:
func = func.show_progress()
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
data = pa.Table.from_pandas(data, preserve_index=False)
embeddings = func(data[column].to_numpy())
table = vec_to_table(np.array(embeddings))
return data.append_column("vector", table["vector"])
@@ -102,4 +102,4 @@ class EmbeddingFunction:
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
return _chunker(arr)
yield from _chunker(arr)

View File

@@ -55,6 +55,27 @@ class LanceTable:
"""Return the schema of the table."""
return self._dataset.schema
def __len__(self):
return self._dataset.count_rows()
def __repr__(self) -> str:
return f"LanceTable({self.name})"
def __str__(self) -> str:
return self.__repr__()
def head(self, n=5) -> pa.Table:
"""Return the first n rows of the table."""
return self._dataset.head(n)
def to_pandas(self) -> pd.DataFrame:
"""Return the table as a pandas DataFrame."""
return self.to_arrow().to_pandas()
def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table."""
return self._dataset.to_table()
@property
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")

View File

@@ -0,0 +1,36 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pyarrow as pa
from lancedb.embeddings import with_embeddings
def mock_embed_func(input_data):
return [np.random.randn(128).tolist() for _ in range(len(input_data))]
def test_with_embeddings():
data = pa.Table.from_arrays(
[
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
names=["text", "price"],
)
data = with_embeddings(mock_embed_func, data)
assert data.num_columns == 3
assert data.num_rows == 2
assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0]