feat!: migrate FTS from tantivy to lance-index (#1483)

Lance now supports FTS, so add it into lancedb Python, TypeScript and
Rust SDKs.

For Python, we still use tantivy based FTS by default because the lance
FTS index now misses some features of tantivy.

For Python:
- Support to create lance based FTS index
- Support to specify columns for full text search (only available for
lance based FTS index)

For TypeScript:
- Change the search method so that it can accept both string and vector
- Support full text search

For Rust
- Support full text search

The others:
- Update the FTS doc

BREAKING CHANGE: 
- for Python, this renames the attached score column of FTS from "score"
to "_score", this could be a breaking change for users that rely the
scores

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-08-08 15:33:15 +08:00
committed by GitHub
parent 4db554eea5
commit f9d5fa88a1
34 changed files with 713 additions and 145 deletions

View File

@@ -99,6 +99,9 @@ class Query(pydantic.BaseModel):
# if True then apply the filter before vector search
prefilter: bool = False
# full text search query
full_text_query: Optional[Union[str, dict]] = None
# top k results to return
k: int
@@ -131,6 +134,7 @@ class LanceQueryBuilder(ABC):
query_type: str,
vector_column_name: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
) -> LanceQueryBuilder:
"""
Create a query builder based on the given query and query type.
@@ -226,6 +230,7 @@ class LanceQueryBuilder(ABC):
self._limit = 10
self._columns = None
self._where = None
self._prefilter = False
self._with_row_id = False
@deprecation.deprecated(
@@ -664,12 +669,19 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
def __init__(self, table: "Table", query: str, ordering_field_name: str = None):
def __init__(
self,
table: "Table",
query: str,
ordering_field_name: str = None,
fts_columns: Union[str, List[str]] = None,
):
super().__init__(table)
self._query = query
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
self._fts_columns = fts_columns
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
"""Set whether to use phrase query.
@@ -689,6 +701,35 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
return self
def to_arrow(self) -> pa.Table:
tantivy_index_path = self._table._get_fts_index_path()
if Path(tantivy_index_path).exists():
return self.tantivy_to_arrow()
query = self._query
if self._phrase_query:
raise NotImplementedError(
"Phrase query is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
if self._reranker:
raise NotImplementedError(
"Reranking is not yet supported in Lance FTS. "
"Use tantivy-based index instead for now."
)
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
prefilter=self._prefilter,
with_row_id=self._with_row_id,
full_text_query={
"query": query,
"columns": self._fts_columns,
},
)
def tantivy_to_arrow(self) -> pa.Table:
try:
import tantivy
except ImportError:
@@ -726,11 +767,11 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
index, query, self._limit, ordering_field=self.ordering_field_name
)
if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("score", pa.float32())])
empty_schema = pa.schema([pa.field("_score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
output_tbl = output_tbl.append_column("_score", scores)
# this needs to match vector search results which are uint64
row_ids = pa.array(row_ids, type=pa.uint64())
@@ -784,8 +825,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
LanceFtsQueryBuilder
The LanceQueryBuilder object.
"""
self._reranker = reranker
return self
raise NotImplementedError("Reranking is not yet supported for FTS queries.")
class LanceEmptyQueryBuilder(LanceQueryBuilder):
@@ -856,13 +896,13 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
# convert to ranks first if needed
if self._norm == "rank":
vector_results = self._rank(vector_results, "_distance")
fts_results = self._rank(fts_results, "score")
fts_results = self._rank(fts_results, "_score")
# normalize the scores to be between 0 and 1, 0 being most relevant
vector_results = self._normalize_scores(vector_results, "_distance")
# In fts higher scores represent relevance. Not inverting them here as
# rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score")
fts_results = self._normalize_scores(fts_results, "_score")
results = self._reranker.rerank_hybrid(
self._fts_query._query, vector_results, fts_results

View File

@@ -220,8 +220,8 @@ class Reranker(ABC):
def _keep_relevance_score(self, combined_results: pa.Table):
if self.score == "relevance":
if "score" in combined_results.column_names:
combined_results = combined_results.drop_columns(["score"])
if "_score" in combined_results.column_names:
combined_results = combined_results.drop_columns(["_score"])
if "_distance" in combined_results.column_names:
combined_results = combined_results.drop_columns(["_distance"])
return combined_results

View File

@@ -113,6 +113,6 @@ class CohereReranker(Reranker):
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
result_set = result_set.drop_columns(["_score"])
return result_set

View File

@@ -105,7 +105,7 @@ class ColbertReranker(Reranker):
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
result_set = result_set.drop_columns(["_score"])
result_set = result_set.sort_by([("_relevance_score", "descending")])

View File

@@ -96,7 +96,7 @@ class CrossEncoderReranker(Reranker):
):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -117,6 +117,6 @@ class JinaReranker(Reranker):
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
result_set = result_set.drop_columns(["_score"])
return result_set

View File

@@ -69,12 +69,12 @@ class LinearCombinationReranker(Reranker):
vi = vector_list[i]
fj = fts_list[j]
# invert the fts score from relevance to distance
inverted_fts_score = self._invert_score(fj["score"])
inverted_fts_score = self._invert_score(fj["_score"])
if vi["_rowid"] == fj["_rowid"]:
vi["_relevance_score"] = self._combine_score(
vi["_distance"], inverted_fts_score
)
vi["score"] = fj["score"] # keep the original score
vi["_score"] = fj["_score"] # keep the original score
combined_list.append(vi)
i += 1
j += 1

View File

@@ -108,7 +108,7 @@ class OpenaiReranker(Reranker):
def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.drop_columns(["_score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])

View File

@@ -59,7 +59,6 @@ from .util import (
if TYPE_CHECKING:
import PIL
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq
@@ -350,6 +349,7 @@ class Table(ABC):
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
*,
replace: bool = True,
):
@@ -511,6 +511,8 @@ class Table(ABC):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1188,9 +1190,15 @@ class LanceTable(Table):
index_cache_size=index_cache_size,
)
def create_scalar_index(self, column: str, *, replace: bool = True):
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
*,
replace: bool = True,
):
self._dataset_mut.create_scalar_index(
column, index_type="BTREE", replace=replace
column, index_type=index_type, replace=replace
)
def create_fts_index(
@@ -1201,6 +1209,7 @@ class LanceTable(Table):
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
tokenizer_name: str = "default",
use_tantivy: bool = True,
):
"""Create a full-text search index on the table.
@@ -1211,6 +1220,7 @@ class LanceTable(Table):
----------
field_names: str or list of str
The name(s) of the field to index.
can be only str if use_tantivy=True for now.
replace: bool, default False
If True, replace the existing index if it exists. Note that this is
not yet an atomic operation; the index will be temporarily
@@ -1218,12 +1228,31 @@ class LanceTable(Table):
writer_heap_size: int, default 1GB
ordering_field_names:
A list of unsigned type fields to index to optionally order
results on at search time
results on at search time.
only available with use_tantivy=True
tokenizer_name: str, default "default"
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
language code followed by "_stem". So for english it would be "en_stem".
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
only available with use_tantivy=True for now
use_tantivy: bool, default False
If True, use the legacy full-text search implementation based on tantivy.
If False, use the new full-text search implementation based on lance-index.
"""
if not use_tantivy:
if not isinstance(field_names, str):
raise ValueError("field_names must be a string when use_tantivy=False")
# delete the existing legacy index if it exists
if replace:
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
fs.delete_dir(path)
self._dataset_mut.create_scalar_index(
field_names, index_type="INVERTED", replace=replace
)
return
from .fts import create_index, populate_index
if isinstance(field_names, str):
@@ -1392,6 +1421,7 @@ class LanceTable(Table):
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1446,6 +1476,10 @@ class LanceTable(Table):
or raise an error if no corresponding embedding function is found.
If the `query` is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
fts_columns: str or list of str, default None
The column(s) to search in for full-text search.
If None then the search is performed on all indexed columns.
For now, only one column can be searched at a time.
Returns
-------
@@ -1665,6 +1699,7 @@ class LanceTable(Table):
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
full_text_query=query.full_text_query,
with_row_id=query.with_row_id,
batch_size=batch_size,
).to_reader()

View File

@@ -22,7 +22,8 @@ import pytest
from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_basic(tmp_path, use_tantivy):
db = lancedb.connect(tmp_path)
assert db.uri == str(tmp_path)
@@ -55,7 +56,7 @@ def test_basic(tmp_path):
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
table.create_fts_index(["item"])
table.create_fts_index("item", use_tantivy=use_tantivy)
rs = table.search("bar", query_type="fts").to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"

View File

@@ -74,7 +74,12 @@ def test_create_index_with_stemming(tmp_path, table):
assert os.path.exists(str(tmp_path / "index"))
# Check stemming by running tokenizer on non empty table
table.create_fts_index("text", tokenizer_name="en_stem")
table.create_fts_index("text", tokenizer_name="en_stem", use_tantivy=True)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_create_inverted_index(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
def test_populate_index(tmp_path, table):
@@ -92,8 +97,15 @@ def test_search_index(tmp_path, table):
assert len(results[1]) == 10 # _distance
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
results = table.search("puppy").limit(10).to_list()
assert len(results) == 10
def test_search_ordering_field_index_table(tmp_path, table):
table.create_fts_index("text", ordering_field_names=["count"])
table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True)
rows = (
table.search("puppy", ordering_field_name="count")
.limit(20)
@@ -125,8 +137,9 @@ def test_search_ordering_field_index(tmp_path, table):
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_create_index_from_table(tmp_path, table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) <= 10
assert "text" in df.columns
@@ -145,15 +158,15 @@ def test_create_index_from_table(tmp_path, table):
]
)
with pytest.raises(ValueError, match="already exists"):
table.create_fts_index("text")
with pytest.raises(Exception, match="already exists"):
table.create_fts_index("text", use_tantivy=use_tantivy)
table.create_fts_index("text", replace=True)
table.create_fts_index("text", replace=True, use_tantivy=use_tantivy)
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])
table.create_fts_index(["text", "text2"], use_tantivy=True)
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
assert "text" in df.columns
@@ -161,20 +174,21 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"])
table.create_fts_index(["text", "text2"], use_tantivy=True)
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 0
def test_nested_schema(tmp_path, table):
table.create_fts_index("nested.text")
table.create_fts_index("nested.text", use_tantivy=True)
rs = table.search("puppy").limit(10).to_list()
assert len(rs) == 10
def test_search_index_with_filter(table):
table.create_fts_index("text")
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_index_with_filter(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
orig_import = __import__
def import_mock(name, *args):
@@ -186,7 +200,7 @@ def test_search_index_with_filter(table):
with mock.patch("builtins.__import__", side_effect=import_mock):
rs = table.search("puppy").where("id=1").limit(10)
# test schema
assert rs.to_arrow().drop("score").schema.equals(table.schema)
assert rs.to_arrow().drop("_score").schema.equals(table.schema)
rs = rs.to_list()
for r in rs:
@@ -204,7 +218,8 @@ def test_search_index_with_filter(table):
assert r["_rowid"] is not None
def test_null_input(table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_null_input(table, use_tantivy):
table.add(
[
{
@@ -217,12 +232,12 @@ def test_null_input(table):
}
]
)
table.create_fts_index("text")
table.create_fts_index("text", use_tantivy=use_tantivy)
def test_syntax(table):
# https://github.com/lancedb/lancedb/issues/769
table.create_fts_index("text")
table.create_fts_index("text", use_tantivy=True)
with pytest.raises(ValueError, match="Syntax Error"):
table.search("they could have been dogs OR").limit(10).to_list()

View File

@@ -22,7 +22,7 @@ from lancedb.table import LanceTable
pytest.importorskip("lancedb.fts")
def get_test_table(tmp_path):
def get_test_table(tmp_path, use_tantivy):
db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
@@ -89,7 +89,7 @@ def get_test_table(tmp_path):
)
# Create a fts index
table.create_fts_index("text")
table.create_fts_index("text", use_tantivy=use_tantivy)
return table, MyTable
@@ -174,8 +174,8 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
assert len(result) == 20 and result == result_arrow
def _run_test_hybrid_reranker(reranker, tmp_path):
table, schema = get_test_table(tmp_path)
def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
table, schema = get_test_table(tmp_path, use_tantivy)
# The default reranker
result1 = (
table.search(
@@ -221,14 +221,16 @@ def _run_test_hybrid_reranker(reranker, tmp_path):
)
def test_linear_combination(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_linear_combination(tmp_path, use_tantivy):
reranker = LinearCombinationReranker()
_run_test_hybrid_reranker(reranker, tmp_path)
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
def test_rrf_reranker(tmp_path):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_rrf_reranker(tmp_path, use_tantivy):
reranker = RRFReranker()
_run_test_hybrid_reranker(reranker, tmp_path)
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
@pytest.mark.skipif(