Order by field support FTS (#1132)

This PR adds support for passing through a set of ordering fields at
index time (unsigned ints that tantivity can use as fast_fields) that at
query time you can sort your results on. This is useful for cases where
you want to get related hits, i.e by keyword, but order those hits by
some other score, such as popularity.

I.e search for songs descriptions that match on "sad AND jazz AND 1920"
and then order those by number of times played. Example usage can be
seen in the fts tests.

---------

Co-authored-by: Nat Roth <natroth@Nats-MacBook-Pro.local>
Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
natcharacter
2024-03-20 04:27:37 -04:00
committed by Weston Pace
parent 4466cfa958
commit f6e9f8e3f4
5 changed files with 125 additions and 10 deletions

View File

@@ -28,7 +28,9 @@ except ImportError:
from .table import LanceTable
def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
def create_index(
index_path: str, text_fields: List[str], ordering_fields: List[str] = None
) -> tantivy.Index:
"""
Create a new Index (not populated)
@@ -38,12 +40,16 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
Path to the index directory
text_fields : List[str]
List of text fields to index
ordering_fields: List[str]
List of unsigned type fields to order by at search time
Returns
-------
index : tantivy.Index
The index object (not yet populated)
"""
if ordering_fields is None:
ordering_fields = []
# Declaring our schema.
schema_builder = tantivy.SchemaBuilder()
# special field that we'll populate with row_id
@@ -51,6 +57,9 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
# data fields
for name in text_fields:
schema_builder.add_text_field(name, stored=True)
if ordering_fields:
for name in ordering_fields:
schema_builder.add_unsigned_field(name, fast=True)
schema = schema_builder.build()
os.makedirs(index_path, exist_ok=True)
index = tantivy.Index(schema, path=index_path)
@@ -62,6 +71,7 @@ def populate_index(
table: LanceTable,
fields: List[str],
writer_heap_size: int = 1024 * 1024 * 1024,
ordering_fields: List[str] = None,
) -> int:
"""
Populate an index with data from a LanceTable
@@ -82,8 +92,11 @@ def populate_index(
int
The number of rows indexed
"""
if ordering_fields is None:
ordering_fields = []
# first check the fields exist and are string or large string type
nested = []
for name in fields:
try:
f = table.schema.field(name) # raises KeyError if not found
@@ -104,7 +117,7 @@ def populate_index(
if len(nested) > 0:
max_nested_level = max([len(name.split(".")) for name in nested])
for b in dataset.to_batches(columns=fields):
for b in dataset.to_batches(columns=fields + ordering_fields):
if max_nested_level > 0:
b = pa.Table.from_batches([b])
for _ in range(max_nested_level - 1):
@@ -115,6 +128,10 @@ def populate_index(
value = b[name][i].as_py()
if value is not None:
doc.add_text(name, value)
for name in ordering_fields:
value = b[name][i].as_py()
if value is not None:
doc.add_unsigned(name, value)
if not doc.is_empty:
doc.add_integer("doc_id", row_id)
writer.add_document(doc)
@@ -149,7 +166,7 @@ def resolve_path(schema, field_name: str) -> pa.Field:
def search_index(
index: tantivy.Index, query: str, limit: int = 10
index: tantivy.Index, query: str, limit: int = 10, ordering_field=None
) -> Tuple[Tuple[int], Tuple[float]]:
"""
Search an index for a query
@@ -172,7 +189,10 @@ def search_index(
searcher = index.searcher()
query = index.parse_query(query)
# get top results
results = searcher.search(query, limit)
if ordering_field:
results = searcher.search(query, limit, order_by_field=ordering_field)
else:
results = searcher.search(query, limit)
if results.count == 0:
return tuple(), tuple()
return tuple(

View File

@@ -120,6 +120,7 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
ordering_field_name: str = None,
) -> LanceQueryBuilder:
"""
Create a query builder based on the given query and query type.
@@ -158,7 +159,9 @@ class LanceQueryBuilder(ABC):
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
return LanceFtsQueryBuilder(
table, query, ordering_field_name=ordering_field_name
)
if isinstance(query, list):
query = np.array(query, dtype=np.float32)
@@ -597,10 +600,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
def __init__(self, table: "Table", query: str):
def __init__(self, table: "Table", query: str, ordering_field_name: str = None):
super().__init__(table)
self._query = query
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
@@ -646,7 +650,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
if self._phrase_query:
query = query.replace('"', "'")
query = f'"{query}"'
row_ids, scores = search_index(index, query, self._limit)
row_ids, scores = search_index(
index, query, self._limit, ordering_field=self.ordering_field_name
)
if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)

View File

@@ -1157,6 +1157,7 @@ class LanceTable(Table):
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
@@ -1175,12 +1176,18 @@ class LanceTable(Table):
not yet an atomic operation; the index will be temporarily
unavailable while the new index is being created.
writer_heap_size: int, default 1GB
ordering_field_names:
A list of unsigned type fields to index to optionally order
results on at search time
"""
from .fts import create_index, populate_index
if isinstance(field_names, str):
field_names = [field_names]
if isinstance(ordering_field_names, str):
ordering_field_names = [ordering_field_names]
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
@@ -1188,8 +1195,18 @@ class LanceTable(Table):
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)
index = create_index(self._get_fts_index_path(), field_names)
populate_index(index, self, field_names, writer_heap_size=writer_heap_size)
index = create_index(
self._get_fts_index_path(),
field_names,
ordering_fields=ordering_field_names,
)
populate_index(
index,
self,
field_names,
ordering_fields=ordering_field_names,
writer_heap_size=writer_heap_size,
)
def _get_fts_index_path(self):
return join_uri(self._dataset_uri, "_indices", "tantivy")
@@ -1320,6 +1337,7 @@ class LanceTable(Table):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
ordering_field_name: Optional[str] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1386,7 +1404,11 @@ class LanceTable(Table):
if vector_column_name is None and query is not None:
vector_column_name = inf_vector_column_query(self.schema)
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name
self,
query,
query_type,
vector_column_name=vector_column_name,
ordering_field_name=ordering_field_name,
)
@classmethod

View File

@@ -43,6 +43,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
)
for _ in range(100)
]
count = [random.randint(1, 10000) for _ in range(100)]
table = db.create_table(
"test",
data=pd.DataFrame(
@@ -52,6 +53,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
"text": text,
"text2": text,
"nested": [{"text": t} for t in text],
"count": count,
}
),
)
@@ -79,6 +81,39 @@ def test_search_index(tmp_path, table):
assert len(results[1]) == 10 # _distance
def test_search_ordering_field_index_table(tmp_path, table):
table.create_fts_index("text", ordering_field_names=["count"])
rows = (
table.search("puppy", ordering_field_name="count")
.limit(20)
.select(["text", "count"])
.to_list()
)
for r in rows:
assert "puppy" in r["text"]
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
def test_search_ordering_field_index(tmp_path, table):
index = ldb.fts.create_index(
str(tmp_path / "index"), ["text"], ordering_fields=["count"]
)
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
index.reload()
results = ldb.fts.search_index(
index, query="puppy", limit=10, ordering_field="count"
)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
rows = table.to_lance().take(results[0]).to_pylist()
for r in rows:
assert "puppy" in r["text"]
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
@@ -94,6 +129,7 @@ def test_create_index_from_table(tmp_path, table):
"text": "gorilla",
"text2": "gorilla",
"nested": {"text": "gorilla"},
"count": 10,
}
]
)
@@ -166,6 +202,7 @@ def test_null_input(table):
"text": None,
"text2": None,
"nested": {"text": None},
"count": 7,
}
]
)