diff --git a/docs/src/fts.md b/docs/src/fts.md index 47f51346..78c20f6b 100644 --- a/docs/src/fts.md +++ b/docs/src/fts.md @@ -29,8 +29,9 @@ uri = "data/sample-lancedb" db = lancedb.connect(uri) table = db.create_table("my_table", - data=[{"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"}, - {"vector": [5.9, 26.5], "text": "There are several kittens playing"}]) + data=[{"vector": [3.1, 4.1], "text": "Frodo was a happy puppy", "meta": "foo"}, + {"vector": [5.9, 26.5], "text": "Sam was a loyal puppy", "meta": "bar"}, + {"vector": [15.9, 6.5], "text": "There are several kittens playing"}]) ``` @@ -64,10 +65,23 @@ table.create_fts_index(["text1", "text2"]) Note that the search API call does not change - you can search over all indexed columns at once. +## Filtering + +Currently the LanceDB full text search feature supports *post-filtering*, meaning filters are +applied on top of the full text search results. This can be invoked via the familiar +`where` syntax: + +```python +table.search("puppy").limit(10).where("meta='foo'").to_list() +``` + ## Current limitations 1. Currently we do not yet support incremental writes. -If you add data after fts index creation, it won't be reflected -in search results until you do a full reindex. + If you add data after fts index creation, it won't be reflected + in search results until you do a full reindex. + +2. We currently only support local filesystem paths for the fts index. + This is a tantivy limitation. We've implemented an object store plugin + but tantivy-py does not yet expose a way to use it. -2. We currently only support local filesystem paths for the fts index. 
\ No newline at end of file diff --git a/python/lancedb/query.py b/python/lancedb/query.py index fe2dc86c..743602ad 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -488,6 +488,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = output_tbl.append_column("score", scores) + + if self._where is not None: + try: + # TODO would be great to have Substrait generate pyarrow compute expressions + # or conversely have pyarrow support SQL expressions using Substrait + import duckdb + + output_tbl = ( + duckdb.sql(f"SELECT * FROM output_tbl") + .filter(self._where) + .to_arrow_table() + ) + except ImportError: + import lance + import tempfile + + # TODO Use "memory://" instead once that's supported + with tempfile.TemporaryDirectory() as tmp: + ds = lance.write_dataset(output_tbl, tmp) + output_tbl = ds.to_table(filter=self._where) + return output_tbl diff --git a/python/pyproject.toml b/python/pyproject.toml index b16f3bbe..98d3e295 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,7 +49,7 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"] +tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] diff --git a/python/tests/test_fts.py b/python/tests/test_fts.py index 2a61f3ca..f09b44ef 100644 --- a/python/tests/test_fts.py +++ b/python/tests/test_fts.py @@ -12,6 +12,7 @@ # limitations under the License. 
import os import random +from unittest import mock import numpy as np import pandas as pd @@ -47,6 +48,7 @@ def table(tmp_path) -> ldb.table.LanceTable: data=pd.DataFrame( { "vector": vectors, + "id": [i % 2 for i in range(100)], "text": text, "text2": text, "nested": [{"text": t} for t in text], @@ -88,6 +90,7 @@ def test_create_index_from_table(tmp_path, table): [ { "vector": np.random.randn(128), + "id": 101, "text": "gorilla", "text2": "gorilla", "nested": {"text": "gorilla"}, @@ -121,3 +124,26 @@ def test_nested_schema(tmp_path, table): table.create_fts_index("nested.text") rs = table.search("puppy").limit(10).to_list() assert len(rs) == 10 + + +def test_search_index_with_filter(table): + table.create_fts_index("text") + orig_import = __import__ + + def import_mock(name, *args): + if name == "duckdb": + raise ImportError + return orig_import(name, *args) + + # no duckdb + with mock.patch("builtins.__import__", side_effect=import_mock): + rs = table.search("puppy").where("id=1").limit(10).to_list() + for r in rs: + assert r["id"] == 1 + + # yes duckdb + rs2 = table.search("puppy").where("id=1").limit(10).to_list() + for r in rs2: + assert r["id"] == 1 + + assert rs == rs2