mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat(python): add post filtering for full text search (#739)
Closes #721 fts will return results as a pyarrow table. Pyarrow tables has a `filter` method but it does not take sql filter strings (only pyarrow compute expressions). Instead, we do one of two things to support `tbl.search("keywords").where("foo=5").limit(10).to_arrow()`: Default path: If duckdb is available then use duckdb to execute the sql filter string on the pyarrow table. Backup path: Otherwise, write the pyarrow table to a lance dataset and then do `to_table(filter=<filter>)` Neither is ideal. Default path has two issues: 1. requires installing an extra library (duckdb) 2. duckdb mangles some fields (like fixed size list => list) Backup path incurs a latency penalty (~20ms on ssd) to write the resultset to disk. In the short term, once #676 is addressed, we can write the dataset to "memory://" instead of disk, this makes the post filter evaluate much quicker (ETA next week). In the longer term, we'd like to be able to evaluate the filter string on the pyarrow Table directly, one possibility being that we use Substrait to generate pyarrow compute expressions from sql string. Or if there's enough progress on pyarrow, it could support Substrait expressions directly (no ETA) --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -488,6 +488,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
scores = pa.array(scores)
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
|
||||
if self._where is not None:
|
||||
try:
|
||||
# TODO would be great to have Substrait generate pyarrow compute expressions
|
||||
# or conversely have pyarrow support SQL expressions using Substrait
|
||||
import duckdb
|
||||
|
||||
output_tbl = (
|
||||
duckdb.sql(f"SELECT * FROM output_tbl")
|
||||
.filter(self._where)
|
||||
.to_arrow_table()
|
||||
)
|
||||
except ImportError:
|
||||
import lance
|
||||
import tempfile
|
||||
|
||||
# TODO Use "memory://" instead once that's supported
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
ds = lance.write_dataset(output_tbl, tmp)
|
||||
output_tbl = ds.to_table(filter=self._where)
|
||||
|
||||
return output_tbl
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
import random
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -47,6 +48,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
data=pd.DataFrame(
|
||||
{
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"nested": [{"text": t} for t in text],
|
||||
@@ -88,6 +90,7 @@ def test_create_index_from_table(tmp_path, table):
|
||||
[
|
||||
{
|
||||
"vector": np.random.randn(128),
|
||||
"id": 101,
|
||||
"text": "gorilla",
|
||||
"text2": "gorilla",
|
||||
"nested": {"text": "gorilla"},
|
||||
@@ -121,3 +124,26 @@ def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text")
|
||||
rs = table.search("puppy").limit(10).to_list()
|
||||
assert len(rs) == 10
|
||||
|
||||
|
||||
def test_search_index_with_filter(table):
|
||||
table.create_fts_index("text")
|
||||
orig_import = __import__
|
||||
|
||||
def import_mock(name, *args):
|
||||
if name == "duckdb":
|
||||
raise ImportError
|
||||
return orig_import(name, *args)
|
||||
|
||||
# no duckdb
|
||||
with mock.patch("builtins.__import__", side_effect=import_mock):
|
||||
rs = table.search("puppy").where("id=1").limit(10).to_list()
|
||||
for r in rs:
|
||||
assert r["id"] == 1
|
||||
|
||||
# yes duckdb
|
||||
rs2 = table.search("puppy").where("id=1").limit(10).to_list()
|
||||
for r in rs2:
|
||||
assert r["id"] == 1
|
||||
|
||||
assert rs == rs2
|
||||
|
||||
Reference in New Issue
Block a user