diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 98964aa1..8f26af19 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -61,6 +61,8 @@ jobs: run: | pip install -e . pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 - pip install pytest pytest-mock + pip install pytest pytest-mock black + - name: Black + run: black --check --diff --no-color --quiet . - name: Run tests - run: pytest -x -v --durations=30 tests \ No newline at end of file + run: pytest -x -v --durations=30 tests diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index ea0e2300..2b2497e6 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -1,10 +1,8 @@ -import builtins import os import pytest # import lancedb so we don't have to in every example -import lancedb @pytest.fixture(autouse=True) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 3414dff4..ae321b21 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -13,7 +13,7 @@ from __future__ import annotations import asyncio -from typing import Literal +from typing import Literal, Union import numpy as np import pandas as pd @@ -48,7 +48,7 @@ class LanceQueryBuilder: def __init__( self, table: "lancedb.table.LanceTable", - query: np.ndarray, + query: Union[np.ndarray, str], vector_column_name: str = VECTOR_COLUMN_NAME, ): self._metric = "L2" diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 4541a32e..82f97730 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -22,6 +22,7 @@ import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc +import pyarrow.fs from lance import LanceDataset from lance.vector import vec_to_table @@ -95,7 +96,8 @@ class LanceTable: def _reset_dataset(self): try: - del self.__dict__["_dataset"] + if "_dataset" in self.__dict__: + del self.__dict__["_dataset"] except AttributeError: pass @@ -281,6 +283,7 @@ class LanceTable: int The number of vectors in the table. """ + # TODO: manage table listing and metadata separately data = _sanitize_data( data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) @@ -326,7 +329,7 @@ class LanceTable: cls, db, name, - data, + data=None, schema=None, mode="create", on_bad_vectors: str = "error", @@ -354,10 +357,12 @@ class LanceTable: The LanceDB instance to create the table in. name: str The name of the table to create. - data: list-of-dict, dict, pd.DataFrame + data: list-of-dict, dict, pd.DataFrame, default None The data to insert into the table. + At least one of `data` or `schema` must be provided. schema: dict, optional The schema of the table. If not provided, the schema is inferred from the data. + At least one of `data` or `schema` must be provided. mode: str, default "create" The mode to use when writing the data. Valid values are "create", "overwrite", and "append". @@ -368,11 +373,16 @@ class LanceTable: The value to use when filling vectors. Only used if on_bad_vectors="fill". """ tbl = LanceTable(db, name) - data = _sanitize_data( - data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value - ) + if data is not None: + data = _sanitize_data( + data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) + else: + if schema is None: + raise ValueError("Either data or schema must be provided") + data = pa.Table.from_pylist([], schema=schema) lance.write_dataset(data, tbl._dataset_uri, mode=mode) - return tbl + return LanceTable(db, name) @classmethod def open(cls, db, name): @@ -384,7 +394,6 @@ class LanceTable: raise FileNotFoundError( f"Table {name} does not exist. Please first call db.create_table({name}, data)" ) - return tbl def delete(self, where: str): diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 019bf22b..a3695d43 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -19,6 +19,7 @@ import numpy as np import pandas as pd import pyarrow as pa import pytest +from lance.vector import vec_to_table from lancedb.db import LanceDBConnection from lancedb.table import LanceTable @@ -89,7 +90,31 @@ def test_create_table(db): assert expected == tbl +def test_empty_table(db): + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.string()), + pa.field("price", pa.float32()), + ] + ) + tbl = LanceTable.create(db, "test", schema=schema) + data = [ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ] + tbl.add(data=data) + + def test_add(db): + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.string()), + pa.field("price", pa.float64()), + ] + ) + table = LanceTable.create( db, "test", @@ -98,7 +123,19 @@ def test_add(db): {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, ], ) + _add(table, schema) + table = LanceTable.create(db, "test2", schema=schema) + table.add( + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + _add(table, schema) + + +def _add(table, schema): # table = LanceTable(db, "test") assert len(table) == 2 @@ -113,13 +150,7 @@ def test_add(db): pa.array(["foo", "bar", "new"]), pa.array([10.0, 20.0, 30.0]), ], - schema=pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 2)), - pa.field("item", pa.string()), - pa.field("price", pa.float64()), - ] - ), + schema=schema, ) assert expected == table.to_arrow()