fix(python): Few fts patches (#1039)

1. filtering with fts mutated the schema, which caused schema mistmatch
problems with hybrid search as it combines fts and vector search tables.
2. fts with filter failed with `with_row_id`. This was because row_id
was calculated before filtering which caused size mismatch on attaching
it after.
3. The fix for 1 meant that now row_id is attached before filtering but
passing a filter to `to_lance` on a dataset that already contains
`_rowid` raises a panic from lance. So temporarily, in case where fts is
used with a filter AND `with_row_id`, we just force user to using the
duckdb pathway.

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-03-05 06:11:59 +05:30
committed by Weston Pace
parent c60a193767
commit b5326d31e9
3 changed files with 33 additions and 8 deletions

View File

@@ -587,19 +587,26 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
# this needs to match vector search results which are uint64
row_ids = pa.array(row_ids, type=pa.uint64())
if self._where is not None:
tmp_name = "__lancedb__duckdb__indexer__"
output_tbl = output_tbl.append_column(
tmp_name, pa.array(range(len(output_tbl)))
)
try:
# TODO would be great to have Substrait generate pyarrow compute
# expressions or conversely have pyarrow support SQL expressions
# using Substrait
import duckdb
output_tbl = (
duckdb.sql("SELECT * FROM output_tbl")
.filter(self._where)
.to_arrow_table()
)
indexer = duckdb.sql(
f"SELECT {tmp_name} FROM output_tbl WHERE {self._where}"
).to_arrow_table()[tmp_name]
output_tbl = output_tbl.take(indexer).drop([tmp_name])
row_ids = row_ids.take(indexer)
except ImportError:
import tempfile
@@ -609,10 +616,11 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
with tempfile.TemporaryDirectory() as tmp:
ds = lance.write_dataset(output_tbl, tmp)
output_tbl = ds.to_table(filter=self._where)
indexer = output_tbl[tmp_name]
row_ids = row_ids.take(indexer)
output_tbl = output_tbl.drop([tmp_name])
if self._with_row_id:
# Need to set this to uint explicitly as vector results are in uint64
row_ids = pa.array(row_ids, type=pa.uint64())
output_tbl = output_tbl.append_column("_rowid", row_ids)
return output_tbl

View File

@@ -137,7 +137,11 @@ def test_search_index_with_filter(table):
# no duckdb
with mock.patch("builtins.__import__", side_effect=import_mock):
rs = table.search("puppy").where("id=1").limit(10).to_list()
rs = table.search("puppy").where("id=1").limit(10)
# test schema
assert rs.to_arrow().drop("score").schema.equals(table.schema)
rs = rs.to_list()
for r in rs:
assert r["id"] == 1
@@ -147,6 +151,10 @@ def test_search_index_with_filter(table):
assert r["id"] == 1
assert rs == rs2
rs = table.search("puppy").where("id=1").with_row_id(True).limit(10).to_list()
for r in rs:
assert r["id"] == 1
assert r["_rowid"] is not None
def test_null_input(table):

View File

@@ -893,8 +893,17 @@ def test_hybrid_search(db, tmp_path):
result3 = table.search(
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3
# with post filters
result = (
table.search("Arrrrggghhhhhhh", query_type="hybrid")
.where("text='Arrrrggghhhhhhh'")
.to_list()
)
len(result) == 1
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]