[Python] Support replace during create_index (#233)

Closes #214
This commit is contained in:
Lei Xu
2023-06-27 16:02:07 -07:00
committed by GitHub
parent c68c236f17
commit 4bc676e26a
3 changed files with 40 additions and 1 deletions

View File

@@ -188,13 +188,15 @@ class LanceTable:
num_partitions=256,
num_sub_vectors=96,
vector_column_name=VECTOR_COLUMN_NAME,
replace: bool = True,
):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index. Valid values are "L2" or "cosine".
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
@@ -202,6 +204,11 @@ class LanceTable:
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
"""
self._dataset.create_index(
column=vector_column_name,
@@ -209,6 +216,7 @@ class LanceTable:
metric=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
replace=replace,
)
self._reset_dataset()

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
import pytest
@@ -120,3 +121,31 @@ def test_delete_table(tmp_path):
db.create_table("test", data=data)
assert db.table_names() == ["test"]
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(
"test",
[
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
],
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
)
with pytest.raises(Exception):
table.create_index(
num_partitions=2,
num_sub_vectors=4,
replace=False,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
replace=True,
)

View File

@@ -164,6 +164,7 @@ def test_create_index_method():
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
)
# Check that the _dataset.create_index method was called with the right parameters
@@ -173,4 +174,5 @@ def test_create_index_method():
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
)