From c8728d4ca1aecb60c2a4ecf4b53c53ccd93beb8e Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:31:04 -0800 Subject: [PATCH] feat(python): add post filtering for full text search (#739) Closes #721 fts will return results as a pyarrow table. Pyarrow tables has a `filter` method but it does not take sql filter strings (only pyarrow compute expressions). Instead, we do one of two things to support `tbl.search("keywords").where("foo=5").limit(10).to_arrow()`: Default path: If duckdb is available then use duckdb to execute the sql filter string on the pyarrow table. Backup path: Otherwise, write the pyarrow table to a lance dataset and then do `to_table(filter=)` Neither is ideal. Default path has two issues: 1. requires installing an extra library (duckdb) 2. duckdb mangles some fields (like fixed size list => list) Backup path incurs a latency penalty (~20ms on ssd) to write the resultset to disk. In the short term, once #676 is addressed, we can write the dataset to "memory://" instead of disk, this makes the post filter evaluate much quicker (ETA next week). In the longer term, we'd like to be able to evaluate the filter string on the pyarrow Table directly, one possibility being that we use Substrait to generate pyarrow compute expressions from sql string. Or if there's enough progress on pyarrow, it could support Substrait expressions directly (no ETA) --------- Co-authored-by: Will Jones --- docs/src/fts.md | 24 +++++++++++++++++++----- python/lancedb/query.py | 21 +++++++++++++++++++++ python/pyproject.toml | 2 +- python/tests/test_fts.py | 26 ++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) diff --git a/docs/src/fts.md b/docs/src/fts.md index 47f51346..78c20f6b 100644 --- a/docs/src/fts.md +++ b/docs/src/fts.md @@ -29,8 +29,9 @@ uri = "data/sample-lancedb" db = lancedb.connect(uri) table = db.create_table("my_table", - data=[{"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"}, - {"vector": [5.9, 26.5], "text": "There are several kittens playing"}]) + data=[{"vector": [3.1, 4.1], "text": "Frodo was a happy puppy", "meta": "foo"}, + {"vector": [5.9, 26.5], "text": "Sam was a loyal puppy", "meta": "bar"}, + {"vector": [15.9, 6.5], "text": "There are several kittens playing"}]) ``` @@ -64,10 +65,23 @@ table.create_fts_index(["text1", "text2"]) Note that the search API call does not change - you can search over all indexed columns at once. +## Filtering + +Currently the LanceDB full text search feature supports *post-filtering*, meaning filters are +applied on top of the full text search results. This can be invoked via the familiar +`where` syntax: + +```python +table.search("puppy").limit(10).where("meta='foo'").to_list() +``` + ## Current limitations 1. Currently we do not yet support incremental writes. -If you add data after fts index creation, it won't be reflected -in search results until you do a full reindex. + If you add data after fts index creation, it won't be reflected + in search results until you do a full reindex. + +2. We currently only support local filesystem paths for the fts index. + This is a tantivy limitation. We've implemented an object store plugin + but there's no way in tantivy-py to specify to use it. -2. We currently only support local filesystem paths for the fts index. \ No newline at end of file diff --git a/python/lancedb/query.py b/python/lancedb/query.py index fe2dc86c..743602ad 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -488,6 +488,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = output_tbl.append_column("score", scores) + + if self._where is not None: + try: + # TODO would be great to have Substrait generate pyarrow compute expressions + # or conversely have pyarrow support SQL expressions using Substrait + import duckdb + + output_tbl = ( + duckdb.sql(f"SELECT * FROM output_tbl") + .filter(self._where) + .to_arrow_table() + ) + except ImportError: + import lance + import tempfile + + # TODO Use "memory://" instead once that's supported + with tempfile.TemporaryDirectory() as tmp: + ds = lance.write_dataset(output_tbl, tmp) + output_tbl = ds.to_table(filter=self._where) + return output_tbl diff --git a/python/pyproject.toml b/python/pyproject.toml index b16f3bbe..98d3e295 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,7 +49,7 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"] +tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] diff --git a/python/tests/test_fts.py b/python/tests/test_fts.py index 2a61f3ca..f09b44ef 100644 --- a/python/tests/test_fts.py +++ b/python/tests/test_fts.py @@ -12,6 +12,7 @@ # limitations under the License. import os import random +from unittest import mock import numpy as np import pandas as pd @@ -47,6 +48,7 @@ def table(tmp_path) -> ldb.table.LanceTable: data=pd.DataFrame( { "vector": vectors, + "id": [i % 2 for i in range(100)], "text": text, "text2": text, "nested": [{"text": t} for t in text], @@ -88,6 +90,7 @@ def test_create_index_from_table(tmp_path, table): [ { "vector": np.random.randn(128), + "id": 101, "text": "gorilla", "text2": "gorilla", "nested": {"text": "gorilla"}, @@ -121,3 +124,26 @@ def test_nested_schema(tmp_path, table): table.create_fts_index("nested.text") rs = table.search("puppy").limit(10).to_list() assert len(rs) == 10 + + +def test_search_index_with_filter(table): + table.create_fts_index("text") + orig_import = __import__ + + def import_mock(name, *args): + if name == "duckdb": + raise ImportError + return orig_import(name, *args) + + # no duckdb + with mock.patch("builtins.__import__", side_effect=import_mock): + rs = table.search("puppy").where("id=1").limit(10).to_list() + for r in rs: + assert r["id"] == 1 + + # yes duckdb + rs2 = table.search("puppy").where("id=1").limit(10).to_list() + for r in rs2: + assert r["id"] == 1 + + assert rs == rs2