invalidate cached dataset after create_index and add

This commit is contained in:
Chang She
2023-04-18 16:39:48 -07:00
parent 3ba7fa15a4
commit f0ea1d898b
2 changed files with 23 additions and 11 deletions

View File

@@ -50,6 +50,12 @@ class LanceTable:
self._conn = connection
self.name = name
def _reset_dataset(self):
    """Discard the memoized ``_dataset`` value so the next access rebuilds it.

    ``_dataset`` is declared as a ``cached_property``, which stores its
    computed value in the instance ``__dict__`` under the property name.
    Removing that entry invalidates the cache. Calling this when nothing
    has been cached yet (or calling it repeatedly) is a no-op.
    """
    try:
        # Deleting a missing dict key raises KeyError, not AttributeError,
        # so ``del`` would crash when the cache is empty; ``pop`` with a
        # default tolerates the missing entry.
        self.__dict__.pop("_dataset", None)
    except AttributeError:
        # The instance has no ``__dict__`` at all, so nothing can be cached.
        pass
@property
def schema(self) -> pa.Schema:
"""Return the schema of the table."""
@@ -92,12 +98,13 @@ class LanceTable:
The number of PQ sub-vectors to use when creating the index.
Default is 96.
"""
return self._dataset.create_index(
self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
index_type="IVF_PQ",
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
)
self._reset_dataset()
@cached_property
def _dataset(self) -> LanceDataset:
@@ -123,8 +130,9 @@ class LanceTable:
The number of vectors added to the table.
"""
data = _sanitize_data(data, self.schema)
ds = lance.write_dataset(data, self._dataset_uri, mode=mode)
return ds.count_rows()
lance.write_dataset(data, self._dataset_uri, mode=mode)
self._reset_dataset()
return len(self)
def search(self, query: VEC) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors

View File

@@ -99,14 +99,18 @@ def test_add(db):
expected = pa.Table.from_arrays(
[
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
pa.FixedSizeListArray.from_arrays(
pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2
),
pa.array(["foo", "bar", "new"]),
pa.array([10.0, 20.0, 30.0]),
],
schema=pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]),
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
),
)
assert expected == table.to_arrow()