mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
fix(python): Few fts patches (#1039)
1. filtering with fts mutated the schema, which caused schema mistmatch problems with hybrid search as it combines fts and vector search tables. 2. fts with filter failed with `with_row_id`. This was because row_id was calculated before filtering which caused size mismatch on attaching it after. 3. The fix for 1 meant that now row_id is attached before filtering but passing a filter to `to_lance` on a dataset that already contains `_rowid` raises a panic from lance. So temporarily, in case where fts is used with a filter AND `with_row_id`, we just force user to using the duckdb pathway. --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
committed by
Weston Pace
parent
c60a193767
commit
b5326d31e9
@@ -587,19 +587,26 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
scores = pa.array(scores)
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
# this needs to match vector search results which are uint64
|
||||
row_ids = pa.array(row_ids, type=pa.uint64())
|
||||
|
||||
if self._where is not None:
|
||||
tmp_name = "__lancedb__duckdb__indexer__"
|
||||
output_tbl = output_tbl.append_column(
|
||||
tmp_name, pa.array(range(len(output_tbl)))
|
||||
)
|
||||
try:
|
||||
# TODO would be great to have Substrait generate pyarrow compute
|
||||
# expressions or conversely have pyarrow support SQL expressions
|
||||
# using Substrait
|
||||
import duckdb
|
||||
|
||||
output_tbl = (
|
||||
duckdb.sql("SELECT * FROM output_tbl")
|
||||
.filter(self._where)
|
||||
.to_arrow_table()
|
||||
)
|
||||
indexer = duckdb.sql(
|
||||
f"SELECT {tmp_name} FROM output_tbl WHERE {self._where}"
|
||||
).to_arrow_table()[tmp_name]
|
||||
output_tbl = output_tbl.take(indexer).drop([tmp_name])
|
||||
row_ids = row_ids.take(indexer)
|
||||
|
||||
except ImportError:
|
||||
import tempfile
|
||||
|
||||
@@ -609,10 +616,11 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
ds = lance.write_dataset(output_tbl, tmp)
|
||||
output_tbl = ds.to_table(filter=self._where)
|
||||
indexer = output_tbl[tmp_name]
|
||||
row_ids = row_ids.take(indexer)
|
||||
output_tbl = output_tbl.drop([tmp_name])
|
||||
|
||||
if self._with_row_id:
|
||||
# Need to set this to uint explicitly as vector results are in uint64
|
||||
row_ids = pa.array(row_ids, type=pa.uint64())
|
||||
output_tbl = output_tbl.append_column("_rowid", row_ids)
|
||||
return output_tbl
|
||||
|
||||
|
||||
@@ -137,7 +137,11 @@ def test_search_index_with_filter(table):
|
||||
|
||||
# no duckdb
|
||||
with mock.patch("builtins.__import__", side_effect=import_mock):
|
||||
rs = table.search("puppy").where("id=1").limit(10).to_list()
|
||||
rs = table.search("puppy").where("id=1").limit(10)
|
||||
# test schema
|
||||
assert rs.to_arrow().drop("score").schema.equals(table.schema)
|
||||
|
||||
rs = rs.to_list()
|
||||
for r in rs:
|
||||
assert r["id"] == 1
|
||||
|
||||
@@ -147,6 +151,10 @@ def test_search_index_with_filter(table):
|
||||
assert r["id"] == 1
|
||||
|
||||
assert rs == rs2
|
||||
rs = table.search("puppy").where("id=1").with_row_id(True).limit(10).to_list()
|
||||
for r in rs:
|
||||
assert r["id"] == 1
|
||||
assert r["_rowid"] is not None
|
||||
|
||||
|
||||
def test_null_input(table):
|
||||
|
||||
@@ -893,8 +893,17 @@ def test_hybrid_search(db, tmp_path):
|
||||
result3 = table.search(
|
||||
"Our father who art in heaven", query_type="hybrid"
|
||||
).to_pydantic(MyTable)
|
||||
|
||||
assert result1 == result3
|
||||
|
||||
# with post filters
|
||||
result = (
|
||||
table.search("Arrrrggghhhhhhh", query_type="hybrid")
|
||||
.where("text='Arrrrggghhhhhhh'")
|
||||
.to_list()
|
||||
)
|
||||
len(result) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||
|
||||
Reference in New Issue
Block a user