diff --git a/docs/src/fts.md b/docs/src/fts.md index ac385197..8d85d563 100644 --- a/docs/src/fts.md +++ b/docs/src/fts.md @@ -75,6 +75,36 @@ applied on top of the full text search results. This can be invoked via the fami table.search("puppy").limit(10).where("meta='foo'").to_list() ``` +## Sorting + +You can pre-sort the documents by specifying `ordering_field_names` when +creating the full-text search index. Once pre-sorted, you can then specify +`ordering_field_name` while searching to return results sorted by the given +field. For example, + +``` +table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"]) + +(table.search("terms", ordering_field_name="sort_by_field") + .limit(20) + .to_list()) +``` + +!!! note + If you wish to specify an ordering field at query time, you must also + have specified it during indexing time. Otherwise at query time, an + error will be raised that looks like `ValueError: The field does not exist: xxx` + +!!! note + The fields to sort on must be of typed unsigned integer, or else you will see + an error during indexing that looks like + `TypeError: argument 'value': 'float' object cannot be interpreted as an integer`. + +!!! note + You can specify multiple fields for ordering at indexing time. + But at query time only one ordering field is supported. + + ## Phrase queries vs. terms queries For full-text search you can specify either a **phrase** query like `"the old man and the sea"`, diff --git a/python/python/lancedb/fts.py b/python/python/lancedb/fts.py index 6efa2723..970f6f2d 100644 --- a/python/python/lancedb/fts.py +++ b/python/python/lancedb/fts.py @@ -28,7 +28,9 @@ except ImportError: from .table import LanceTable -def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index: +def create_index( + index_path: str, text_fields: List[str], ordering_fields: List[str] = None +) -> tantivy.Index: """ Create a new Index (not populated) @@ -38,12 +40,16 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index: Path to the index directory text_fields : List[str] List of text fields to index + ordering_fields: List[str] + List of unsigned type fields to order by at search time Returns ------- index : tantivy.Index The index object (not yet populated) """ + if ordering_fields is None: + ordering_fields = [] # Declaring our schema. schema_builder = tantivy.SchemaBuilder() # special field that we'll populate with row_id @@ -51,6 +57,9 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index: # data fields for name in text_fields: schema_builder.add_text_field(name, stored=True) + if ordering_fields: + for name in ordering_fields: + schema_builder.add_unsigned_field(name, fast=True) schema = schema_builder.build() os.makedirs(index_path, exist_ok=True) index = tantivy.Index(schema, path=index_path) @@ -62,6 +71,7 @@ def populate_index( table: LanceTable, fields: List[str], writer_heap_size: int = 1024 * 1024 * 1024, + ordering_fields: List[str] = None, ) -> int: """ Populate an index with data from a LanceTable @@ -82,8 +92,11 @@ def populate_index( int The number of rows indexed """ + if ordering_fields is None: + ordering_fields = [] # first check the fields exist and are string or large string type nested = [] + for name in fields: try: f = table.schema.field(name) # raises KeyError if not found @@ -104,7 +117,7 @@ def populate_index( if len(nested) > 0: max_nested_level = max([len(name.split(".")) for name in nested]) - for b in dataset.to_batches(columns=fields): + for b in dataset.to_batches(columns=fields + ordering_fields): if max_nested_level > 0: b = pa.Table.from_batches([b]) for _ in range(max_nested_level - 1): @@ -115,6 +128,10 @@ def populate_index( value = b[name][i].as_py() if value is not None: doc.add_text(name, value) + for name in ordering_fields: + value = b[name][i].as_py() + if value is not None: + doc.add_unsigned(name, value) if not doc.is_empty: doc.add_integer("doc_id", row_id) writer.add_document(doc) @@ -149,7 +166,7 @@ def resolve_path(schema, field_name: str) -> pa.Field: def search_index( - index: tantivy.Index, query: str, limit: int = 10 + index: tantivy.Index, query: str, limit: int = 10, ordering_field=None ) -> Tuple[Tuple[int], Tuple[float]]: """ Search an index for a query @@ -172,7 +189,10 @@ def search_index( searcher = index.searcher() query = index.parse_query(query) # get top results - results = searcher.search(query, limit) + if ordering_field: + results = searcher.search(query, limit, order_by_field=ordering_field) + else: + results = searcher.search(query, limit) if results.count == 0: return tuple(), tuple() return tuple( diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 8a7b231c..abc577e0 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -120,6 +120,7 @@ class LanceQueryBuilder(ABC): query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query_type: str, vector_column_name: str, + ordering_field_name: str = None, ) -> LanceQueryBuilder: """ Create a query builder based on the given query and query type. @@ -158,7 +159,9 @@ class LanceQueryBuilder(ABC): if isinstance(query, str): # fts - return LanceFtsQueryBuilder(table, query) + return LanceFtsQueryBuilder( + table, query, ordering_field_name=ordering_field_name + ) if isinstance(query, list): query = np.array(query, dtype=np.float32) @@ -597,10 +600,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" - def __init__(self, table: "Table", query: str): + def __init__(self, table: "Table", query: str, ordering_field_name: str = None): super().__init__(table) self._query = query self._phrase_query = False + self.ordering_field_name = ordering_field_name self._reranker = None def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: @@ -646,7 +650,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._phrase_query: query = query.replace('"', "'") query = f'"{query}"' - row_ids, scores = search_index(index, query, self._limit) + row_ids, scores = search_index( + index, query, self._limit, ordering_field=self.ordering_field_name + ) if len(row_ids) == 0: empty_schema = pa.schema([pa.field("score", pa.float32())]) return pa.Table.from_pylist([], schema=empty_schema) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 9e5227d5..75d4a2b0 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1157,6 +1157,7 @@ class LanceTable(Table): def create_fts_index( self, field_names: Union[str, List[str]], + ordering_field_names: Union[str, List[str]] = None, *, replace: bool = False, writer_heap_size: Optional[int] = 1024 * 1024 * 1024, @@ -1175,12 +1176,18 @@ class LanceTable(Table): not yet an atomic operation; the index will be temporarily unavailable while the new index is being created. writer_heap_size: int, default 1GB + ordering_field_names: + A list of unsigned type fields to index to optionally order + results on at search time """ from .fts import create_index, populate_index if isinstance(field_names, str): field_names = [field_names] + if isinstance(ordering_field_names, str): + ordering_field_names = [ordering_field_names] + fs, path = fs_from_uri(self._get_fts_index_path()) index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound if index_exists: @@ -1188,8 +1195,18 @@ class LanceTable(Table): raise ValueError("Index already exists. Use replace=True to overwrite.") fs.delete_dir(path) - index = create_index(self._get_fts_index_path(), field_names) - populate_index(index, self, field_names, writer_heap_size=writer_heap_size) + index = create_index( + self._get_fts_index_path(), + field_names, + ordering_fields=ordering_field_names, + ) + populate_index( + index, + self, + field_names, + ordering_fields=ordering_field_names, + writer_heap_size=writer_heap_size, + ) def _get_fts_index_path(self): return join_uri(self._dataset_uri, "_indices", "tantivy") @@ -1320,6 +1337,7 @@ class LanceTable(Table): query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, query_type: str = "auto", + ordering_field_name: Optional[str] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1386,7 +1404,11 @@ class LanceTable(Table): if vector_column_name is None and query is not None: vector_column_name = inf_vector_column_query(self.schema) return LanceQueryBuilder.create( - self, query, query_type, vector_column_name=vector_column_name + self, + query, + query_type, + vector_column_name=vector_column_name, + ordering_field_name=ordering_field_name, ) @classmethod diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index fbf74662..d04e0cb2 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -43,6 +43,7 @@ def table(tmp_path) -> ldb.table.LanceTable: ) for _ in range(100) ] + count = [random.randint(1, 10000) for _ in range(100)] table = db.create_table( "test", data=pd.DataFrame( @@ -52,6 +53,7 @@ def table(tmp_path) -> ldb.table.LanceTable: "text": text, "text2": text, "nested": [{"text": t} for t in text], + "count": count, } ), ) @@ -79,6 +81,39 @@ def test_search_index(tmp_path, table): assert len(results[1]) == 10 # _distance +def test_search_ordering_field_index_table(tmp_path, table): + table.create_fts_index("text", ordering_field_names=["count"]) + rows = ( + table.search("puppy", ordering_field_name="count") + .limit(20) + .select(["text", "count"]) + .to_list() + ) + for r in rows: + assert "puppy" in r["text"] + assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows + + +def test_search_ordering_field_index(tmp_path, table): + index = ldb.fts.create_index( + str(tmp_path / "index"), ["text"], ordering_fields=["count"] + ) + + ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"]) + index.reload() + results = ldb.fts.search_index( + index, query="puppy", limit=10, ordering_field="count" + ) + assert len(results) == 2 + assert len(results[0]) == 10 # row_ids + assert len(results[1]) == 10 # _distance + rows = table.to_lance().take(results[0]).to_pylist() + + for r in rows: + assert "puppy" in r["text"] + assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows + + def test_create_index_from_table(tmp_path, table): table.create_fts_index("text") df = table.search("puppy").limit(10).select(["text"]).to_pandas() @@ -94,6 +129,7 @@ def test_create_index_from_table(tmp_path, table): "text": "gorilla", "text2": "gorilla", "nested": {"text": "gorilla"}, + "count": 10, } ] ) @@ -166,6 +202,7 @@ def test_null_input(table): "text": None, "text2": None, "nested": {"text": None}, + "count": 7, } ] )