diff --git a/python/lancedb/table.py b/python/lancedb/table.py index b4e4daae..82a0c229 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -188,13 +188,15 @@ class LanceTable: num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME, + replace: bool = True, ): """Create an index on the table. Parameters ---------- metric: str, default "L2" - The distance metric to use when creating the index. Valid values are "L2" or "cosine". + The distance metric to use when creating the index. + Valid values are "L2", "cosine", or "dot". L2 is euclidean distance. num_partitions: int The number of IVF partitions to use when creating the index. @@ -202,6 +204,11 @@ class LanceTable: num_sub_vectors: int The number of PQ sub-vectors to use when creating the index. Default is 96. + vector_column_name: str, default "vector" + The vector column name to create the index. + replace: bool, default True + If True, replace the existing index if it exists. + If False, raise an error if duplicate index exists. """ self._dataset.create_index( column=vector_column_name, @@ -209,6 +216,7 @@ class LanceTable: metric=metric, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, + replace=replace, ) self._reset_dataset() diff --git a/python/tests/test_db.py b/python/tests/test_db.py index f0fc5bb4..ad1645aa 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pandas as pd import pytest @@ -120,3 +121,31 @@ def test_delete_table(tmp_path): db.create_table("test", data=data) assert db.table_names() == ["test"] + + +def test_replace_index(tmp_path): + db = lancedb.connect(uri=tmp_path) + table = db.create_table( + "test", + [ + {"vector": np.random.rand(128), "item": "foo", "price": float(i)} + for i in range(1000) + ], + ) + table.create_index( + num_partitions=2, + num_sub_vectors=4, + ) + + with pytest.raises(Exception): + table.create_index( + num_partitions=2, + num_sub_vectors=4, + replace=False, + ) + + table.create_index( + num_partitions=2, + num_sub_vectors=4, + replace=True, + ) diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 6021dc2b..e2e9b64b 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -164,6 +164,7 @@ def test_create_index_method(): num_partitions=256, num_sub_vectors=96, vector_column_name="vector", + replace=True, ) # Check that the _dataset.create_index method was called with the right parameters @@ -173,4 +174,5 @@ def test_create_index_method(): metric="L2", num_partitions=256, num_sub_vectors=96, + replace=True, )