diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 0b8fa9c6..c9170e2d 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -23,6 +23,7 @@ import lance import numpy as np import pyarrow as pa import pyarrow.compute as pc +import pyarrow.fs as pa_fs from lance import LanceDataset from lance.vector import vec_to_table @@ -575,7 +576,9 @@ class LanceTable(Table): ) self._reset_dataset() - def create_fts_index(self, field_names: Union[str, List[str]]): + def create_fts_index( + self, field_names: Union[str, List[str]], *, replace: bool = False + ): """Create a full-text search index on the table. Warning - this API is highly experimental and is highly likely to change @@ -585,11 +588,25 @@ class LanceTable(Table): ---------- field_names: str or list of str The name(s) of the field to index. + replace: bool, default False + If True, replace the existing index if it exists. Note that this is + not yet an atomic operation; the index will be temporarily + unavailable while the new index is being created. """ from .fts import create_index, populate_index if isinstance(field_names, str): field_names = [field_names] + + fs, path = fs_from_uri(self._get_fts_index_path()) + index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound + if index_exists: + if not replace: + raise ValueError( + f"Index already exists. Use replace=True to overwrite." + ) + fs.delete_dir(path) + index = create_index(self._get_fts_index_path(), field_names) populate_index(index, self, field_names) diff --git a/python/tests/test_fts.py b/python/tests/test_fts.py index 301a4d23..2a61f3ca 100644 --- a/python/tests/test_fts.py +++ b/python/tests/test_fts.py @@ -83,6 +83,24 @@ def test_create_index_from_table(tmp_path, table): assert len(df) == 10 assert "text" in df.columns + # Check whether it can be updated + table.add( + [ + { + "vector": np.random.randn(128), + "text": "gorilla", + "text2": "gorilla", + "nested": {"text": "gorilla"}, + } + ] + ) + + with pytest.raises(ValueError, match="already exists"): + table.create_fts_index("text") + + table.create_fts_index("text", replace=True) + assert len(table.search("gorilla").limit(1).to_pandas()) == 1 + def test_create_index_multiple_columns(tmp_path, table): table.create_fts_index(["text", "text2"])