diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 9a2bf395..e34fdc15 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -397,14 +397,6 @@ class LanceTable(Table): self.name = name self._version = version - def _reset_dataset(self, version=None): - try: - if "_dataset" in self.__dict__: - del self.__dict__["_dataset"] - self._version = version - except AttributeError: - pass - @property def schema(self) -> pa.Schema: """Return the schema of the table. @@ -458,7 +450,6 @@ class LanceTable(Table): max_ver = max([v["version"] for v in self._dataset.versions()]) if version < 1 or version > max_ver: raise ValueError(f"Invalid version {version}") - self._reset_dataset(version=version) try: # Accessing the property updates the cached value @@ -518,7 +509,6 @@ class LanceTable(Table): return self._dataset.restore() - self._reset_dataset() def __len__(self): return self._dataset.count_rows() @@ -575,7 +565,6 @@ class LanceTable(Table): accelerator=accelerator, index_cache_size=index_cache_size, ) - self._reset_dataset() register_event("create_index") def create_fts_index( @@ -661,9 +650,8 @@ class LanceTable(Table): metadata=self.schema.metadata, on_bad_vectors=on_bad_vectors, fill_value=fill_value, - ) - lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode) - self._reset_dataset() + ) + self._dataset.write(data, mode=mode) register_event("add") def merge( @@ -727,7 +715,6 @@ class LanceTable(Table): self._dataset.merge( other_table, left_on=left_on, right_on=right_on, schema=schema ) - self._reset_dataset() register_event("merge") @cached_property @@ -985,7 +972,6 @@ class LanceTable(Table): values_sql = {k: value_to_sql(v) for k, v in values.items()} self.to_lance().update(values_sql, where) - self._reset_dataset() register_event("update") def _execute_query(self, query: Query) -> pa.Table: diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 3ae193a9..3f06fede 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -226,39 +226,38 @@ def test_versioning(db): def test_create_index_method(): - with patch.object(LanceTable, "_reset_dataset", return_value=None): - with patch.object( - LanceTable, "_dataset", new_callable=PropertyMock - ) as mock_dataset: - # Setup mock responses - mock_dataset.return_value.create_index.return_value = None + with patch.object( + LanceTable, "_dataset", new_callable=PropertyMock + ) as mock_dataset: + # Setup mock responses + mock_dataset.return_value.create_index.return_value = None - # Create a LanceTable object - connection = LanceDBConnection(uri="mock.uri") - table = LanceTable(connection, "test_table") + # Create a LanceTable object + connection = LanceDBConnection(uri="mock.uri") + table = LanceTable(connection, "test_table") - # Call the create_index method - table.create_index( - metric="L2", - num_partitions=256, - num_sub_vectors=96, - vector_column_name="vector", - replace=True, - index_cache_size=256, - ) + # Call the create_index method + table.create_index( + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name="vector", + replace=True, + index_cache_size=256, + ) - # Check that the _dataset.create_index method was called - # with the right parameters - mock_dataset.return_value.create_index.assert_called_once_with( - column="vector", - index_type="IVF_PQ", - metric="L2", - num_partitions=256, - num_sub_vectors=96, - replace=True, - accelerator=None, - index_cache_size=256, - ) + # Check that the _dataset.create_index method was called + # with the right parameters + mock_dataset.return_value.create_index.assert_called_once_with( + column="vector", + index_type="IVF_PQ", + metric="L2", + num_partitions=256, + num_sub_vectors=96, + replace=True, + accelerator=None, + index_cache_size=256, + ) def test_add_with_nans(db):