diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 04cbaa77..4816acf0 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -587,19 +587,26 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = output_tbl.append_column("score", scores) + # this needs to match vector search results which are uint64 + row_ids = pa.array(row_ids, type=pa.uint64()) if self._where is not None: + tmp_name = "__lancedb__duckdb__indexer__" + output_tbl = output_tbl.append_column( + tmp_name, pa.array(range(len(output_tbl))) + ) try: # TODO would be great to have Substrait generate pyarrow compute # expressions or conversely have pyarrow support SQL expressions # using Substrait import duckdb - output_tbl = ( - duckdb.sql("SELECT * FROM output_tbl") - .filter(self._where) - .to_arrow_table() - ) + indexer = duckdb.sql( + f"SELECT {tmp_name} FROM output_tbl WHERE {self._where}" + ).to_arrow_table()[tmp_name] + output_tbl = output_tbl.take(indexer).drop([tmp_name]) + row_ids = row_ids.take(indexer) + except ImportError: import tempfile @@ -609,10 +616,11 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): with tempfile.TemporaryDirectory() as tmp: ds = lance.write_dataset(output_tbl, tmp) output_tbl = ds.to_table(filter=self._where) + indexer = output_tbl[tmp_name] + row_ids = row_ids.take(indexer) + output_tbl = output_tbl.drop([tmp_name]) if self._with_row_id: - # Need to set this to uint explicitly as vector results are in uint64 - row_ids = pa.array(row_ids, type=pa.uint64()) output_tbl = output_tbl.append_column("_rowid", row_ids) return output_tbl diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index e884d605..aa6bfa61 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -137,7 +137,11 @@ def test_search_index_with_filter(table): # no duckdb with mock.patch("builtins.__import__", side_effect=import_mock): - rs = table.search("puppy").where("id=1").limit(10).to_list() + rs = table.search("puppy").where("id=1").limit(10) + # test schema + assert rs.to_arrow().drop("score").schema.equals(table.schema) + + rs = rs.to_list() for r in rs: assert r["id"] == 1 @@ -147,6 +151,10 @@ def test_search_index_with_filter(table): assert r["id"] == 1 assert rs == rs2 + rs = table.search("puppy").where("id=1").with_row_id(True).limit(10).to_list() + for r in rs: + assert r["id"] == 1 + assert r["_rowid"] is not None def test_null_input(table): diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 564c3829..ccad7d58 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -893,8 +893,17 @@ def test_hybrid_search(db, tmp_path): result3 = table.search( "Our father who art in heaven", query_type="hybrid" ).to_pydantic(MyTable) + assert result1 == result3 + # with post filters + result = ( + table.search("Arrrrggghhhhhhh", query_type="hybrid") + .where("text='Arrrrggghhhhhhh'") + .to_list() + ) + len(result) == 1 + @pytest.mark.parametrize( "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]