mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
invalidate cached dataset after create_index and add
This commit is contained in:
@@ -50,6 +50,12 @@ class LanceTable:
|
||||
self._conn = connection
|
||||
self.name = name
|
||||
|
||||
def _reset_dataset(self):
    """Drop the cached ``_dataset`` value so the next access rebuilds it.

    ``_dataset`` is a ``cached_property``, so its computed value is stored
    in the instance ``__dict__`` under the key ``"_dataset"``. Removing
    that entry forces the property to be evaluated again on next access.

    This is a no-op when nothing has been cached yet, so it is safe to
    call before the first ``_dataset`` access or several times in a row.
    """
    # A missing dict key raises KeyError (not AttributeError), so use
    # pop() with a default instead of ``del`` inside a try/except.
    self.__dict__.pop("_dataset", None)
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the schema of the table."""
|
||||
@@ -92,12 +98,13 @@ class LanceTable:
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
"""
|
||||
return self._dataset.create_index(
|
||||
self._dataset.create_index(
|
||||
column=VECTOR_COLUMN_NAME,
|
||||
index_type="IVF_PQ",
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
@cached_property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
@@ -123,8 +130,9 @@ class LanceTable:
|
||||
The number of vectors added to the table.
|
||||
"""
|
||||
data = _sanitize_data(data, self.schema)
|
||||
ds = lance.write_dataset(data, self._dataset_uri, mode=mode)
|
||||
return ds.count_rows()
|
||||
lance.write_dataset(data, self._dataset_uri, mode=mode)
|
||||
self._reset_dataset()
|
||||
return len(self)
|
||||
|
||||
def search(self, query: VEC) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
|
||||
@@ -99,14 +99,18 @@ def test_add(db):
|
||||
|
||||
expected = pa.Table.from_arrays(
|
||||
[
|
||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
pa.FixedSizeListArray.from_arrays(
|
||||
pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2
|
||||
),
|
||||
pa.array(["foo", "bar", "new"]),
|
||||
pa.array([10.0, 20.0, 30.0]),
|
||||
],
|
||||
schema=pa.schema([
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]),
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
),
|
||||
)
|
||||
assert expected == table.to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user