From f0ea1d898b3264802ca732a4f0da467faba2b282 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Tue, 18 Apr 2023 16:39:48 -0700 Subject: [PATCH] invalidate cached dataset after create_index and add --- python/lancedb/table.py | 14 +++++++++++--- python/tests/test_table.py | 20 ++++++++++++-------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index abff62fe..5a79306c 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -50,6 +50,12 @@ class LanceTable: self._conn = connection self.name = name + def _reset_dataset(self): + try: + del self.__dict__["_dataset"] + except AttributeError: + pass + @property def schema(self) -> pa.Schema: """Return the schema of the table.""" @@ -92,12 +98,13 @@ class LanceTable: The number of PQ sub-vectors to use when creating the index. Default is 96. """ - return self._dataset.create_index( + self._dataset.create_index( column=VECTOR_COLUMN_NAME, index_type="IVF_PQ", num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, ) + self._reset_dataset() @cached_property def _dataset(self) -> LanceDataset: @@ -123,8 +130,9 @@ class LanceTable: The number of vectors added to the table. """ data = _sanitize_data(data, self.schema) - ds = lance.write_dataset(data, self._dataset_uri, mode=mode) - return ds.count_rows() + lance.write_dataset(data, self._dataset_uri, mode=mode) + self._reset_dataset() + return len(self) def search(self, query: VEC) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 56ef0f2a..d5699faa 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -99,14 +99,18 @@ def test_add(db): expected = pa.Table.from_arrays( [ - pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2), - pa.array(["foo", "bar"]), - pa.array([10.0, 20.0]), + pa.FixedSizeListArray.from_arrays( + pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2 + ), + pa.array(["foo", "bar", "new"]), + pa.array([10.0, 20.0, 30.0]), ], - schema=pa.schema([ - pa.field("vector", pa.list_(pa.float32(), 2)), - pa.field("item", pa.string()), - pa.field("price", pa.float64()), - ]), + schema=pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.string()), + pa.field("price", pa.float64()), + ] + ), ) assert expected == table.to_arrow()