mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
refactor(python): remove legacy tantivy FTS support (#3282)
This follows the Rust-side Tantivy removal by deleting the remaining Python Tantivy runtime, tests, and packaging references. It also turns the legacy Python-only Tantivy parameters into explicit errors and stops reading legacy `_indices/fts` directories so Python FTS is fully native-only.
This commit is contained in:
@@ -1,201 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Full text search index using tantivy-py"""
|
||||
|
||||
import os
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
try:
|
||||
import tantivy
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install tantivy-py `pip install tantivy` to use the full text search feature." # noqa: E501
|
||||
)
|
||||
|
||||
from .table import LanceTable
|
||||
|
||||
|
||||
def create_index(
|
||||
index_path: str,
|
||||
text_fields: List[str],
|
||||
ordering_fields: Optional[List[str]] = None,
|
||||
tokenizer_name: str = "default",
|
||||
) -> tantivy.Index:
|
||||
"""
|
||||
Create a new Index (not populated)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index_path : str
|
||||
Path to the index directory
|
||||
text_fields : List[str]
|
||||
List of text fields to index
|
||||
ordering_fields: List[str]
|
||||
List of unsigned type fields to order by at search time
|
||||
tokenizer_name : str, default "default"
|
||||
The tokenizer to use
|
||||
|
||||
Returns
|
||||
-------
|
||||
index : tantivy.Index
|
||||
The index object (not yet populated)
|
||||
"""
|
||||
if ordering_fields is None:
|
||||
ordering_fields = []
|
||||
# Declaring our schema.
|
||||
schema_builder = tantivy.SchemaBuilder()
|
||||
# special field that we'll populate with row_id
|
||||
schema_builder.add_integer_field("doc_id", stored=True)
|
||||
# data fields
|
||||
for name in text_fields:
|
||||
schema_builder.add_text_field(name, stored=True, tokenizer_name=tokenizer_name)
|
||||
if ordering_fields:
|
||||
for name in ordering_fields:
|
||||
schema_builder.add_unsigned_field(name, fast=True)
|
||||
schema = schema_builder.build()
|
||||
os.makedirs(index_path, exist_ok=True)
|
||||
index = tantivy.Index(schema, path=index_path)
|
||||
return index
|
||||
|
||||
|
||||
def populate_index(
|
||||
index: tantivy.Index,
|
||||
table: LanceTable,
|
||||
fields: List[str],
|
||||
writer_heap_size: Optional[int] = None,
|
||||
ordering_fields: Optional[List[str]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Populate an index with data from a LanceTable
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : tantivy.Index
|
||||
The index object
|
||||
table : LanceTable
|
||||
The table to index
|
||||
fields : List[str]
|
||||
List of fields to index
|
||||
writer_heap_size : int
|
||||
The writer heap size in bytes, defaults to 1GB
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of rows indexed
|
||||
"""
|
||||
if ordering_fields is None:
|
||||
ordering_fields = []
|
||||
writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
|
||||
# first check the fields exist and are string or large string type
|
||||
nested = []
|
||||
|
||||
for name in fields:
|
||||
try:
|
||||
f = table.schema.field(name) # raises KeyError if not found
|
||||
except KeyError:
|
||||
f = resolve_path(table.schema, name)
|
||||
nested.append(name)
|
||||
|
||||
if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type):
|
||||
raise TypeError(f"Field {name} is not a string type")
|
||||
|
||||
# create a tantivy writer
|
||||
writer = index.writer(heap_size=writer_heap_size)
|
||||
# write data into index
|
||||
dataset = table.to_lance()
|
||||
row_id = 0
|
||||
|
||||
max_nested_level = 0
|
||||
if len(nested) > 0:
|
||||
max_nested_level = max([len(name.split(".")) for name in nested])
|
||||
|
||||
for b in dataset.to_batches(columns=fields + ordering_fields):
|
||||
if max_nested_level > 0:
|
||||
b = pa.Table.from_batches([b])
|
||||
for _ in range(max_nested_level - 1):
|
||||
b = b.flatten()
|
||||
for i in range(b.num_rows):
|
||||
doc = tantivy.Document()
|
||||
for name in fields:
|
||||
value = b[name][i].as_py()
|
||||
if value is not None:
|
||||
doc.add_text(name, value)
|
||||
for name in ordering_fields:
|
||||
value = b[name][i].as_py()
|
||||
if value is not None:
|
||||
doc.add_unsigned(name, value)
|
||||
if not doc.is_empty:
|
||||
doc.add_integer("doc_id", row_id)
|
||||
writer.add_document(doc)
|
||||
row_id += 1
|
||||
# commit changes
|
||||
writer.commit()
|
||||
return row_id
|
||||
|
||||
|
||||
def resolve_path(schema, field_name: str) -> pa.Field:
|
||||
"""
|
||||
Resolve a nested field path to a list of field names
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field_name : str
|
||||
The field name to resolve
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
The resolved path
|
||||
"""
|
||||
path = field_name.split(".")
|
||||
field = schema.field(path.pop(0))
|
||||
for segment in path:
|
||||
if pa.types.is_struct(field.type):
|
||||
field = field.type.field(segment)
|
||||
else:
|
||||
raise KeyError(f"field {field_name} not found in schema {schema}")
|
||||
return field
|
||||
|
||||
|
||||
def search_index(
|
||||
index: tantivy.Index, query: str, limit: int = 10, ordering_field=None
|
||||
) -> Tuple[Tuple[int], Tuple[float]]:
|
||||
"""
|
||||
Search an index for a query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : tantivy.Index
|
||||
The index object
|
||||
query : str
|
||||
The query string
|
||||
limit : int
|
||||
The maximum number of results to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
ids_and_score: list[tuple[int], tuple[float]]
|
||||
A tuple of two tuples, the first containing the document ids
|
||||
and the second containing the scores
|
||||
"""
|
||||
searcher = index.searcher()
|
||||
query = index.parse_query(query)
|
||||
# get top results
|
||||
if ordering_field:
|
||||
results = searcher.search(query, limit, order_by_field=ordering_field)
|
||||
else:
|
||||
results = searcher.search(query, limit)
|
||||
if results.count == 0:
|
||||
return tuple(), tuple()
|
||||
return tuple(
|
||||
zip(
|
||||
*[
|
||||
(searcher.doc(doc_address)["doc_id"][0], score)
|
||||
for score, doc_address in results.hits
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -25,7 +25,6 @@ import deprecation
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
import pydantic
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
@@ -1526,9 +1525,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
return self._table._output_schema(self.to_query_object())
|
||||
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
return self.tantivy_to_arrow()
|
||||
self._table._ensure_no_legacy_fts_index()
|
||||
|
||||
query = self._query
|
||||
if self._phrase_query:
|
||||
@@ -1552,90 +1549,6 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
):
|
||||
raise NotImplementedError("to_batches on an FTS query")
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
import tantivy
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install tantivy-py `pip install tantivy` to use the full text search feature." # noqa: E501
|
||||
)
|
||||
|
||||
from .fts import search_index
|
||||
|
||||
# get the index path
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
|
||||
# check if the index exist
|
||||
if not exist:
|
||||
raise FileNotFoundError(
|
||||
"Fts index does not exist. "
|
||||
"Please first call table.create_fts_index(['<field_names>']) to "
|
||||
"create the fts index."
|
||||
)
|
||||
|
||||
# Check that we are on local filesystem
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Tantivy-based full text search "
|
||||
"is only supported on the local filesystem"
|
||||
)
|
||||
# open the index
|
||||
index = tantivy.Index.open(path)
|
||||
# get the scores and doc ids
|
||||
query = self._query
|
||||
if self._phrase_query:
|
||||
query = query.replace('"', "'")
|
||||
query = f'"{query}"'
|
||||
limit = self._limit if self._limit is not None else 10
|
||||
row_ids, scores = search_index(
|
||||
index, query, limit, ordering_field=self.ordering_field_name
|
||||
)
|
||||
if len(row_ids) == 0:
|
||||
empty_schema = pa.schema([pa.field("_score", pa.float32())])
|
||||
return pa.Table.from_batches([], schema=empty_schema)
|
||||
scores = pa.array(scores)
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("_score", scores)
|
||||
# this needs to match vector search results which are uint64
|
||||
row_ids = pa.array(row_ids, type=pa.uint64())
|
||||
|
||||
if self._where is not None:
|
||||
tmp_name = "__lancedb__duckdb__indexer__"
|
||||
output_tbl = output_tbl.append_column(
|
||||
tmp_name, pa.array(range(len(output_tbl)))
|
||||
)
|
||||
try:
|
||||
# TODO would be great to have Substrait generate pyarrow compute
|
||||
# expressions or conversely have pyarrow support SQL expressions
|
||||
# using Substrait
|
||||
import duckdb
|
||||
|
||||
indexer = duckdb.sql(
|
||||
f"SELECT {tmp_name} FROM output_tbl WHERE {self._where}"
|
||||
).to_arrow_table()[tmp_name]
|
||||
output_tbl = output_tbl.take(indexer).drop([tmp_name])
|
||||
row_ids = row_ids.take(indexer)
|
||||
|
||||
except ImportError:
|
||||
import tempfile
|
||||
|
||||
import lance
|
||||
|
||||
# TODO Use "memory://" instead once that's supported
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
ds = lance.write_dataset(output_tbl, tmp)
|
||||
output_tbl = ds.to_table(filter=self._where)
|
||||
indexer = output_tbl[tmp_name]
|
||||
row_ids = row_ids.take(indexer)
|
||||
output_tbl = output_tbl.drop([tmp_name])
|
||||
|
||||
if self._with_row_id:
|
||||
output_tbl = output_tbl.append_column("_rowid", row_ids)
|
||||
|
||||
if self._reranker is not None:
|
||||
output_tbl = self._reranker.rerank_fts(self._query, output_tbl)
|
||||
return output_tbl
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
|
||||
@@ -943,29 +943,26 @@ class Table(ABC):
|
||||
Parameters
|
||||
----------
|
||||
field_names: str or list of str
|
||||
The name(s) of the field to index.
|
||||
If ``use_tantivy`` is False (default), only a single field name
|
||||
(str) is supported. To index multiple fields, create a separate
|
||||
FTS index for each field.
|
||||
The name of the field to index. Native FTS indexes can only be
|
||||
created on a single field at a time. To search over multiple text
|
||||
fields, create a separate FTS index for each field.
|
||||
replace: bool, default False
|
||||
If True, replace the existing index if it exists. Note that this is
|
||||
not yet an atomic operation; the index will be temporarily
|
||||
unavailable while the new index is being created.
|
||||
writer_heap_size: int, default 1GB
|
||||
Only available with use_tantivy=True
|
||||
Deprecated legacy Tantivy parameter. Any value other than the
|
||||
default raises an error.
|
||||
ordering_field_names:
|
||||
A list of unsigned type fields to index to optionally order
|
||||
results on at search time.
|
||||
only available with use_tantivy=True
|
||||
Deprecated legacy Tantivy parameter. Setting this raises an error.
|
||||
tokenizer_name: str, default "default"
|
||||
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
|
||||
language code followed by "_stem". So for english it would be "en_stem".
|
||||
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
|
||||
A compatibility alias for native tokenizer configs. Can be "raw",
|
||||
"default" or the 2 letter language code followed by "_stem". So
|
||||
for english it would be "en_stem".
|
||||
use_tantivy: bool, default False
|
||||
If True, use the legacy full-text search implementation based on tantivy.
|
||||
If False, use the new full-text search implementation based on lance-index.
|
||||
Deprecated legacy Tantivy parameter. Setting this to True raises an
|
||||
error.
|
||||
with_position: bool, default False
|
||||
Only available with use_tantivy=False
|
||||
If False, do not store the positions of the terms in the text.
|
||||
This can reduce the size of the index and improve indexing speed.
|
||||
But it will raise an exception for phrase queries.
|
||||
@@ -1746,6 +1743,16 @@ class Table(ABC):
|
||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||
return (path, fs, index_exists)
|
||||
|
||||
def _ensure_no_legacy_fts_index(self):
|
||||
path, _, exists = self._get_fts_index_path()
|
||||
if exists:
|
||||
raise ValueError(
|
||||
"Legacy Tantivy FTS index detected at "
|
||||
f"{path}. Tantivy-based FTS has been removed. "
|
||||
"Delete the legacy index and recreate it with "
|
||||
"table.create_fts_index(...)."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def uses_v2_manifest_paths(self) -> bool:
|
||||
"""
|
||||
@@ -2405,84 +2412,63 @@ class LanceTable(Table):
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
if not use_tantivy:
|
||||
if not isinstance(field_names, str):
|
||||
raise ValueError(
|
||||
"Native FTS indexes can only be created on a single field "
|
||||
"at a time. To search over multiple text fields, create a "
|
||||
"separate FTS index for each field."
|
||||
)
|
||||
self._ensure_no_legacy_fts_index()
|
||||
|
||||
if tokenizer_name is None:
|
||||
tokenizer_configs = {
|
||||
"base_tokenizer": base_tokenizer,
|
||||
"language": language,
|
||||
"with_position": with_position,
|
||||
"max_token_length": max_token_length,
|
||||
"lower_case": lower_case,
|
||||
"stem": stem,
|
||||
"remove_stop_words": remove_stop_words,
|
||||
"ascii_folding": ascii_folding,
|
||||
"ngram_min_length": ngram_min_length,
|
||||
"ngram_max_length": ngram_max_length,
|
||||
"prefix_only": prefix_only,
|
||||
}
|
||||
else:
|
||||
tokenizer_configs = self.infer_tokenizer_configs(tokenizer_name)
|
||||
|
||||
config = FTS(
|
||||
**tokenizer_configs,
|
||||
if use_tantivy:
|
||||
raise ValueError(
|
||||
"Tantivy-based FTS has been removed. "
|
||||
"Remove use_tantivy and recreate the index with native FTS."
|
||||
)
|
||||
|
||||
# delete the existing legacy index if it exists
|
||||
if replace:
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
fs.delete_dir(path)
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
field_names,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
)
|
||||
if ordering_field_names is not None:
|
||||
raise ValueError(
|
||||
"ordering_field_names was only supported by the removed "
|
||||
"Tantivy-based FTS implementation."
|
||||
)
|
||||
return
|
||||
|
||||
from .fts import create_index, populate_index
|
||||
|
||||
if isinstance(field_names, str):
|
||||
field_names = [field_names]
|
||||
|
||||
if isinstance(ordering_field_names, str):
|
||||
ordering_field_names = [ordering_field_names]
|
||||
|
||||
path, fs, exist = self._get_fts_index_path()
|
||||
if exist:
|
||||
if not replace:
|
||||
raise ValueError("Index already exists. Use replace=True to overwrite.")
|
||||
fs.delete_dir(path)
|
||||
|
||||
if not isinstance(fs, pa_fs.LocalFileSystem):
|
||||
raise NotImplementedError(
|
||||
"Full-text search is only supported on the local filesystem"
|
||||
if writer_heap_size != 1024 * 1024 * 1024:
|
||||
raise ValueError(
|
||||
"writer_heap_size was only supported by the removed "
|
||||
"Tantivy-based FTS implementation."
|
||||
)
|
||||
if not isinstance(field_names, str):
|
||||
raise ValueError(
|
||||
"Native FTS indexes can only be created on a single field "
|
||||
"at a time. To search over multiple text fields, create a "
|
||||
"separate FTS index for each field."
|
||||
)
|
||||
if "." in field_names:
|
||||
raise ValueError(
|
||||
"Native FTS indexes can only be created on top-level fields. "
|
||||
f"Received nested field path: {field_names!r}."
|
||||
)
|
||||
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = "default"
|
||||
index = create_index(
|
||||
path,
|
||||
field_names,
|
||||
ordering_fields=ordering_field_names,
|
||||
tokenizer_name=tokenizer_name,
|
||||
tokenizer_configs = {
|
||||
"base_tokenizer": base_tokenizer,
|
||||
"language": language,
|
||||
"with_position": with_position,
|
||||
"max_token_length": max_token_length,
|
||||
"lower_case": lower_case,
|
||||
"stem": stem,
|
||||
"remove_stop_words": remove_stop_words,
|
||||
"ascii_folding": ascii_folding,
|
||||
"ngram_min_length": ngram_min_length,
|
||||
"ngram_max_length": ngram_max_length,
|
||||
"prefix_only": prefix_only,
|
||||
}
|
||||
else:
|
||||
tokenizer_configs = self.infer_tokenizer_configs(tokenizer_name)
|
||||
|
||||
config = FTS(
|
||||
**tokenizer_configs,
|
||||
)
|
||||
populate_index(
|
||||
index,
|
||||
self,
|
||||
field_names,
|
||||
ordering_fields=ordering_field_names,
|
||||
writer_heap_size=writer_heap_size,
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
field_names,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -180,7 +180,7 @@ def test_fts_fuzzy_query():
|
||||
),
|
||||
mode="overwrite",
|
||||
)
|
||||
table.create_fts_index("text", use_tantivy=False, replace=True)
|
||||
table.create_fts_index("text", replace=True)
|
||||
|
||||
results = table.search(MatchQuery("foo", "text", fuzziness=1)).to_pandas()
|
||||
assert len(results) == 4
|
||||
@@ -230,7 +230,7 @@ def test_fts_boost_query():
|
||||
),
|
||||
mode="overwrite",
|
||||
)
|
||||
table.create_fts_index("desc", use_tantivy=False, replace=True)
|
||||
table.create_fts_index("desc", replace=True)
|
||||
|
||||
results = table.search(
|
||||
BoostQuery(
|
||||
@@ -265,7 +265,7 @@ def test_fts_boolean_query(tmp_path):
|
||||
],
|
||||
mode="overwrite",
|
||||
)
|
||||
table.create_fts_index("text", use_tantivy=False, replace=True)
|
||||
table.create_fts_index("text", replace=True)
|
||||
|
||||
# SHOULD
|
||||
results = table.search(
|
||||
@@ -319,9 +319,7 @@ def test_fts_native():
|
||||
],
|
||||
)
|
||||
|
||||
# passing `use_tantivy=False` to use lance FTS index
|
||||
# `use_tantivy=True` by default
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
table.create_fts_index("text")
|
||||
table.search("puppy").limit(10).select(["text"]).to_list()
|
||||
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
|
||||
# ...
|
||||
@@ -332,7 +330,6 @@ def test_fts_native():
|
||||
# --8<-- [start:fts_config_folding]
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=False,
|
||||
language="French",
|
||||
stem=True,
|
||||
ascii_folding=True,
|
||||
@@ -346,7 +343,7 @@ def test_fts_native():
|
||||
table.search("puppy").limit(10).where("text='foo'", prefilter=False).to_list()
|
||||
# --8<-- [end:fts_postfiltering]
|
||||
# --8<-- [start:fts_with_position]
|
||||
table.create_fts_index("text", use_tantivy=False, with_position=True, replace=True)
|
||||
table.create_fts_index("text", with_position=True, replace=True)
|
||||
# --8<-- [end:fts_with_position]
|
||||
# --8<-- [start:fts_incremental_index]
|
||||
table.add([{"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"}])
|
||||
|
||||
@@ -15,8 +15,7 @@ import pytest
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_basic(tmp_path, use_tantivy):
|
||||
def test_basic(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
assert db.uri == str(tmp_path)
|
||||
@@ -49,7 +48,7 @@ def test_basic(tmp_path, use_tantivy):
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "foo"
|
||||
|
||||
table.create_fts_index("item", use_tantivy=use_tantivy)
|
||||
table.create_fts_index("item")
|
||||
rs = table.search("bar", query_type="fts").to_pandas()
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "bar"
|
||||
|
||||
@@ -36,9 +36,6 @@ import pytest
|
||||
import pytest_asyncio
|
||||
from utils import exception_output
|
||||
|
||||
pytest.importorskip("lancedb.fts")
|
||||
tantivy = pytest.importorskip("tantivy")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> ldb.table.LanceTable:
|
||||
@@ -144,58 +141,53 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
return table
|
||||
|
||||
|
||||
def test_create_index(tmp_path):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
assert isinstance(index, tantivy.Index)
|
||||
assert os.path.exists(str(tmp_path / "index"))
|
||||
@pytest.mark.parametrize(
|
||||
("kwargs", "match"),
|
||||
[
|
||||
(
|
||||
{"use_tantivy": True},
|
||||
"Tantivy-based FTS has been removed",
|
||||
),
|
||||
(
|
||||
{"ordering_field_names": ["count"]},
|
||||
"ordering_field_names was only supported",
|
||||
),
|
||||
(
|
||||
{"writer_heap_size": 128},
|
||||
"writer_heap_size was only supported",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_reject_removed_tantivy_parameters(table, kwargs, match):
|
||||
with pytest.raises(ValueError, match=match):
|
||||
table.create_fts_index("text", **kwargs)
|
||||
|
||||
|
||||
def test_create_index_with_stemming(tmp_path, table):
|
||||
index = ldb.fts.create_index(
|
||||
str(tmp_path / "index"), ["text"], tokenizer_name="en_stem"
|
||||
)
|
||||
assert isinstance(index, tantivy.Index)
|
||||
assert os.path.exists(str(tmp_path / "index"))
|
||||
def test_reject_legacy_tantivy_index(table):
|
||||
path, _, _ = table._get_fts_index_path()
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
# Check stemming by running tokenizer on non empty table
|
||||
table.create_fts_index("text", tokenizer_name="en_stem", use_tantivy=True)
|
||||
with pytest.raises(ValueError, match="Legacy Tantivy FTS index detected"):
|
||||
table.search("puppy").limit(5).to_list()
|
||||
|
||||
with pytest.raises(ValueError, match="Legacy Tantivy FTS index detected"):
|
||||
table.create_fts_index("text")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
@pytest.mark.parametrize("with_position", [True, False])
|
||||
def test_create_inverted_index(table, use_tantivy, with_position):
|
||||
if use_tantivy and not with_position:
|
||||
pytest.skip("we don't support building a tantivy index without position")
|
||||
def test_create_inverted_index(table, with_position):
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=use_tantivy,
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
if not use_tantivy:
|
||||
indices = table.list_indices()
|
||||
fts_indices = [i for i in indices if i.index_type == "FTS"]
|
||||
assert any(i.name == "custom_fts_index" for i in fts_indices)
|
||||
indices = table.list_indices()
|
||||
fts_indices = [i for i in indices if i.index_type == "FTS"]
|
||||
assert any(i.name == "custom_fts_index" for i in fts_indices)
|
||||
|
||||
|
||||
def test_populate_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
assert ldb.fts.populate_index(index, table, ["text"]) == len(table)
|
||||
|
||||
|
||||
def test_search_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
ldb.fts.populate_index(index, table, ["text"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(index, query="puppy", limit=5)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_search_fts(table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
def test_search_fts(table):
|
||||
table.create_fts_index("text")
|
||||
results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
@@ -204,53 +196,52 @@ def test_search_fts(table, use_tantivy):
|
||||
results = table.search("puppy").select(["id", "text"]).to_list()
|
||||
assert len(results) == 10
|
||||
|
||||
if not use_tantivy:
|
||||
# Test with a query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
# Test with a query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
|
||||
# Test boost query
|
||||
results = (
|
||||
table.search(
|
||||
BoostQuery(
|
||||
MatchQuery("puppy", "text"),
|
||||
MatchQuery("runs", "text"),
|
||||
)
|
||||
# Test boost query
|
||||
results = (
|
||||
table.search(
|
||||
BoostQuery(
|
||||
MatchQuery("puppy", "text"),
|
||||
MatchQuery("runs", "text"),
|
||||
)
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
|
||||
# Test multi match query
|
||||
table.create_fts_index("text2", use_tantivy=use_tantivy)
|
||||
results = (
|
||||
table.search(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
# Test multi match query
|
||||
table.create_fts_index("text2")
|
||||
results = (
|
||||
table.search(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
|
||||
# Test boolean query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
for r in results:
|
||||
assert "puppy" in r["text"]
|
||||
assert "runs" in r["text"]
|
||||
# Test boolean query
|
||||
results = (
|
||||
table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text"))
|
||||
.select(["id", "text"])
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) == 5
|
||||
assert len(results[0]) == 3 # id, text, _score
|
||||
for r in results:
|
||||
assert "puppy" in r["text"]
|
||||
assert "runs" in r["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -318,13 +309,13 @@ async def test_fts_select_async(async_table):
|
||||
|
||||
|
||||
def test_search_fts_phrase_query(table):
|
||||
table.create_fts_index("text", use_tantivy=False, with_position=False)
|
||||
table.create_fts_index("text", with_position=False)
|
||||
try:
|
||||
phrase_results = table.search('"puppy runs"').limit(100).to_list()
|
||||
assert False
|
||||
except Exception:
|
||||
pass
|
||||
table.create_fts_index("text", use_tantivy=False, with_position=True, replace=True)
|
||||
table.create_fts_index("text", with_position=True, replace=True)
|
||||
results = table.search("puppy").limit(100).to_list()
|
||||
|
||||
# Test with quotation marks
|
||||
@@ -375,8 +366,8 @@ async def test_search_fts_phrase_query_async(async_table):
|
||||
|
||||
|
||||
def test_search_fts_specify_column(table):
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
table.create_fts_index("text2", use_tantivy=False)
|
||||
table.create_fts_index("text")
|
||||
table.create_fts_index("text2")
|
||||
|
||||
results = table.search("puppy", fts_columns="text").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
@@ -470,42 +461,8 @@ async def test_search_fts_specify_column_async(async_table):
|
||||
pass
|
||||
|
||||
|
||||
def test_search_ordering_field_index_table(tmp_path, table):
|
||||
table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True)
|
||||
rows = (
|
||||
table.search("puppy", ordering_field_name="count")
|
||||
.limit(20)
|
||||
.select(["text", "count"])
|
||||
.to_list()
|
||||
)
|
||||
for r in rows:
|
||||
assert "puppy" in r["text"]
|
||||
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
|
||||
|
||||
|
||||
def test_search_ordering_field_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(
|
||||
str(tmp_path / "index"), ["text"], ordering_fields=["count"]
|
||||
)
|
||||
|
||||
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
|
||||
index.reload()
|
||||
results = ldb.fts.search_index(
|
||||
index, query="puppy", limit=5, ordering_field="count"
|
||||
)
|
||||
assert len(results) == 2
|
||||
assert len(results[0]) == 5 # row_ids
|
||||
assert len(results[1]) == 5 # _distance
|
||||
rows = table.to_lance().take(results[0]).to_pylist()
|
||||
|
||||
for r in rows:
|
||||
assert "puppy" in r["text"]
|
||||
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
def test_create_index_from_table(tmp_path, table):
|
||||
table.create_fts_index("text")
|
||||
df = table.search("puppy").limit(5).select(["text"]).to_pandas()
|
||||
assert len(df) <= 5
|
||||
assert "text" in df.columns
|
||||
@@ -525,36 +482,24 @@ def test_create_index_from_table(tmp_path, table, use_tantivy):
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
table.create_fts_index("text")
|
||||
|
||||
table.create_fts_index("text", replace=True, use_tantivy=use_tantivy)
|
||||
table.create_fts_index("text", replace=True)
|
||||
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
|
||||
|
||||
|
||||
def test_create_index_multiple_columns(tmp_path, table):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 5
|
||||
assert "text" in df.columns
|
||||
assert "text2" in df.columns
|
||||
|
||||
|
||||
def test_empty_rs(tmp_path, table, mocker):
|
||||
table.create_fts_index(["text", "text2"], use_tantivy=True)
|
||||
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
||||
df = table.search("puppy").limit(5).to_pandas()
|
||||
assert len(df) == 0
|
||||
with pytest.raises(ValueError, match="Native FTS indexes can only be created"):
|
||||
table.create_fts_index(["text", "text2"])
|
||||
|
||||
|
||||
def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text", use_tantivy=True)
|
||||
rs = table.search("puppy").limit(5).to_list()
|
||||
assert len(rs) == 5
|
||||
with pytest.raises(ValueError, match="top-level fields"):
|
||||
table.create_fts_index("nested.text")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_search_index_with_filter(table, use_tantivy):
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
def test_search_index_with_filter(table):
|
||||
table.create_fts_index("text")
|
||||
orig_import = __import__
|
||||
|
||||
def import_mock(name, *args):
|
||||
@@ -584,8 +529,7 @@ def test_search_index_with_filter(table, use_tantivy):
|
||||
assert r["_rowid"] is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_null_input(table, use_tantivy):
|
||||
def test_null_input(table):
|
||||
table.add(
|
||||
[
|
||||
{
|
||||
@@ -598,14 +542,13 @@ def test_null_input(table, use_tantivy):
|
||||
}
|
||||
]
|
||||
)
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
table.create_fts_index("text")
|
||||
|
||||
|
||||
def test_syntax(table):
|
||||
# https://github.com/lancedb/lancedb/issues/769
|
||||
table.create_fts_index("text", use_tantivy=True)
|
||||
with pytest.raises(ValueError, match="Syntax Error"):
|
||||
table.search("they could have been dogs OR").limit(10).to_list()
|
||||
table.create_fts_index("text")
|
||||
table.search("they could have been dogs OR").limit(10).to_list()
|
||||
|
||||
# these should work
|
||||
|
||||
@@ -616,6 +559,7 @@ def test_syntax(table):
|
||||
).to_list()
|
||||
|
||||
# phrase queries
|
||||
table.create_fts_index("text", with_position=True, replace=True)
|
||||
table.search("they could have been dogs OR cats").phrase_query().limit(10).to_list()
|
||||
table.search('"they could have been dogs OR cats"').limit(10).to_list()
|
||||
table.search('''"the cats OR dogs were not really 'pets' at all"''').limit(
|
||||
@@ -639,7 +583,7 @@ def test_language(mem_db: DBConnection):
|
||||
table = mem_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
table.create_fts_index("text", use_tantivy=False, language="klingon")
|
||||
table.create_fts_index("text", language="klingon")
|
||||
|
||||
assert exception_output(e) == (
|
||||
"ValueError: LanceDB does not support the requested language: 'klingon'\n"
|
||||
@@ -650,7 +594,6 @@ def test_language(mem_db: DBConnection):
|
||||
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=False,
|
||||
language="French",
|
||||
stem=True,
|
||||
ascii_folding=True,
|
||||
@@ -690,7 +633,7 @@ def test_fts_on_list(mem_db: DBConnection):
|
||||
}
|
||||
)
|
||||
table = mem_db.create_table("test", data=data)
|
||||
table.create_fts_index("text", use_tantivy=False, with_position=True)
|
||||
table.create_fts_index("text", with_position=True)
|
||||
|
||||
res = table.search("lance").limit(5).to_list()
|
||||
assert len(res) == 3
|
||||
@@ -702,7 +645,7 @@ def test_fts_on_list(mem_db: DBConnection):
|
||||
def test_fts_ngram(mem_db: DBConnection):
|
||||
data = pa.table({"text": ["hello world", "lance database", "lance is cool"]})
|
||||
table = mem_db.create_table("test", data=data)
|
||||
table.create_fts_index("text", use_tantivy=False, base_tokenizer="ngram")
|
||||
table.create_fts_index("text", base_tokenizer="ngram")
|
||||
|
||||
results = table.search("lan", query_type="fts").limit(10).to_list()
|
||||
assert len(results) == 2
|
||||
@@ -721,7 +664,6 @@ def test_fts_ngram(mem_db: DBConnection):
|
||||
# test setting min_ngram_length and prefix_only
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=False,
|
||||
base_tokenizer="ngram",
|
||||
replace=True,
|
||||
ngram_min_length=2,
|
||||
@@ -886,7 +828,7 @@ def test_fts_query_to_json():
|
||||
|
||||
|
||||
def test_fts_fast_search(table):
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
table.create_fts_index("text")
|
||||
|
||||
# Insert some unindexed data
|
||||
table.add(
|
||||
|
||||
@@ -28,7 +28,7 @@ def sync_table(tmpdir_factory) -> Table:
|
||||
}
|
||||
)
|
||||
table = db.create_table("test", data)
|
||||
table.create_fts_index("text", with_position=False, use_tantivy=False)
|
||||
table.create_fts_index("text", with_position=False)
|
||||
return table
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ def table_with_id(tmpdir_factory) -> Table:
|
||||
}
|
||||
)
|
||||
table = db.create_table("test_with_id", data)
|
||||
table.create_fts_index("text", with_position=False, use_tantivy=False)
|
||||
table.create_fts_index("text", with_position=False)
|
||||
return table
|
||||
|
||||
|
||||
|
||||
@@ -1385,7 +1385,7 @@ def test_query_timeout(tmp_path):
|
||||
}
|
||||
)
|
||||
table = db.create_table("test", data)
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
table.create_fts_index("text")
|
||||
|
||||
with pytest.raises(Exception, match="Query timeout"):
|
||||
table.search().where("text = 'a'").to_list(timeout=timedelta(0))
|
||||
|
||||
@@ -26,11 +26,8 @@ from lancedb.rerankers import (
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
# Tests rely on FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
|
||||
def get_test_table(tmp_path, use_tantivy):
|
||||
def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
|
||||
@@ -98,7 +95,7 @@ def get_test_table(tmp_path, use_tantivy):
|
||||
)
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy, replace=True)
|
||||
table.create_fts_index("text", replace=True)
|
||||
|
||||
return table, MyTable
|
||||
|
||||
@@ -208,8 +205,8 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
|
||||
def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
table.search(
|
||||
@@ -285,8 +282,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_linear_combination(tmp_path, use_tantivy):
|
||||
def test_linear_combination(tmp_path):
|
||||
reranker = LinearCombinationReranker()
|
||||
|
||||
vector_results = pa.Table.from_pydict(
|
||||
@@ -313,22 +309,20 @@ def test_linear_combination(tmp_path, use_tantivy):
|
||||
assert "_score" not in combined_results.column_names
|
||||
assert "_relevance_score" in combined_results.column_names
|
||||
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
_run_test_hybrid_reranker(reranker, tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_rrf_reranker(tmp_path, use_tantivy):
|
||||
def test_rrf_reranker(tmp_path):
|
||||
reranker = RRFReranker()
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
_run_test_hybrid_reranker(reranker, tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_mrr_reranker(tmp_path, use_tantivy):
|
||||
def test_mrr_reranker(tmp_path):
|
||||
reranker = MRRReranker()
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
_run_test_hybrid_reranker(reranker, tmp_path)
|
||||
|
||||
# Test multi-vector part
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
query = "single player experience"
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
rs2 = (
|
||||
@@ -363,7 +357,7 @@ def test_rrf_reranker_distance():
|
||||
table = db.create_table("test", data)
|
||||
|
||||
table.create_index(num_partitions=1, num_sub_vectors=2)
|
||||
table.create_fts_index("text", use_tantivy=False)
|
||||
table.create_fts_index("text")
|
||||
|
||||
reranker = RRFReranker(return_score="all")
|
||||
|
||||
@@ -422,35 +416,31 @@ def test_rrf_reranker_distance():
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cohere_reranker(tmp_path, use_tantivy):
|
||||
def test_cohere_reranker(tmp_path):
|
||||
pytest.importorskip("cohere")
|
||||
reranker = CohereReranker()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cross_encoder_reranker(tmp_path, use_tantivy):
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_colbert_reranker(tmp_path, use_tantivy):
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("rerankers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_answerdotai_reranker(tmp_path, use_tantivy):
|
||||
def test_answerdotai_reranker(tmp_path):
|
||||
pytest.importorskip("rerankers")
|
||||
reranker = AnswerdotaiRerankers()
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@@ -459,10 +449,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy):
|
||||
or os.environ.get("OPENAI_BASE_URL") is not None,
|
||||
reason="OPENAI_API_KEY not set",
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_openai_reranker(tmp_path, use_tantivy):
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
reranker = OpenaiReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
@@ -470,10 +459,9 @@ def test_openai_reranker(tmp_path, use_tantivy):
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_jina_reranker(tmp_path, use_tantivy):
|
||||
def test_jina_reranker(tmp_path):
|
||||
pytest.importorskip("jina")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
reranker = JinaReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
@@ -481,11 +469,10 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_voyageai_reranker(tmp_path, use_tantivy):
|
||||
def test_voyageai_reranker(tmp_path):
|
||||
pytest.importorskip("voyageai")
|
||||
reranker = VoyageAIReranker(model_name="rerank-2.5")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@@ -504,7 +491,7 @@ def test_empty_result_reranker():
|
||||
|
||||
# Create empty table with schema
|
||||
empty_table = db.create_table("empty_table", schema=schema, mode="overwrite")
|
||||
empty_table.create_fts_index("text", use_tantivy=False, replace=True)
|
||||
empty_table.create_fts_index("text", replace=True)
|
||||
for reranker in [
|
||||
CrossEncoderReranker(),
|
||||
# ColbertReranker(),
|
||||
@@ -603,11 +590,10 @@ def test_empty_hybrid_result_reranker():
|
||||
assert "_rowid" in result.column_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy):
|
||||
def test_cross_encoder_reranker_return_all(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker(return_score="all")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
table, schema = get_test_table(tmp_path)
|
||||
query = "single player experience"
|
||||
result = (
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
|
||||
@@ -242,8 +242,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
|
||||
|
||||
# FTS indices should error since they are not supported yet.
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Full-text search is only supported on the local filesystem",
|
||||
ValueError,
|
||||
match="Tantivy-based FTS has been removed",
|
||||
):
|
||||
table.create_fts_index("x", use_tantivy=True)
|
||||
|
||||
|
||||
@@ -1948,7 +1948,6 @@ def setup_hybrid_search_table(db: DBConnection, embedding_func):
|
||||
|
||||
def test_hybrid_search(tmp_db: DBConnection):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
pytest.importorskip("lance")
|
||||
|
||||
table, MyTable, emb = setup_hybrid_search_table(tmp_db, "test")
|
||||
@@ -2019,7 +2018,6 @@ def test_hybrid_search(tmp_db: DBConnection):
|
||||
|
||||
def test_hybrid_search_metric_type(tmp_db: DBConnection):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
pytest.importorskip("lance")
|
||||
|
||||
# Need to use nonnorm as the embedding function so l2 and dot results
|
||||
|
||||
Reference in New Issue
Block a user