From b60a2177aee38eb2a2b6da5f43a95845dd612ee4 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 20 Dec 2023 12:28:53 -0800 Subject: [PATCH] feat(python): support nested reference for fts (#723) https://github.com/lancedb/lance/issues/1739 Support nested field reference in full text search --------- Co-authored-by: Will Jones --- python/lancedb/fts.py | 41 +++++++++++++++++++++++++++++++++++++++- python/tests/test_fts.py | 16 +++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/python/lancedb/fts.py b/python/lancedb/fts.py index fe3895d7..f187be8d 100644 --- a/python/lancedb/fts.py +++ b/python/lancedb/fts.py @@ -75,8 +75,14 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) - The number of rows indexed """ # first check the fields exist and are string or large string type + nested = [] for name in fields: - f = table.schema.field(name) # raises KeyError if not found + try: + f = table.schema.field(name) # raises KeyError if not found + except KeyError: + f = resolve_path(table.schema, name) + nested.append(name) + if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type): raise TypeError(f"Field {name} is not a string type") @@ -85,7 +91,16 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) - # write data into index dataset = table.to_lance() row_id = 0 + + max_nested_level = 0 + if len(nested) > 0: + max_nested_level = max([len(name.split(".")) for name in nested]) + for b in dataset.to_batches(columns=fields): + if max_nested_level > 0: + b = pa.Table.from_batches([b]) + for _ in range(max_nested_level - 1): + b = b.flatten() for i in range(b.num_rows): doc = tantivy.Document() doc.add_integer("doc_id", row_id) @@ -98,6 +113,30 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) - return row_id +def resolve_path(schema, field_name: str) -> pa.Field: + """ + Resolve a nested field path to a list of field names + + Parameters + ---------- + field_name : str + The field name to resolve + + Returns + ------- + List[str] + The resolved path + """ + path = field_name.split(".") + field = schema.field(path.pop(0)) + for segment in path: + if pa.types.is_struct(field.type): + field = field.type.field(segment) + else: + raise KeyError(f"field {field_name} not found in schema {schema}") + return field + + def search_index( index: tantivy.Index, query: str, limit: int = 10 ) -> Tuple[Tuple[int], Tuple[float]]: diff --git a/python/tests/test_fts.py b/python/tests/test_fts.py index c5c69859..301a4d23 100644 --- a/python/tests/test_fts.py +++ b/python/tests/test_fts.py @@ -43,7 +43,15 @@ def table(tmp_path) -> ldb.table.LanceTable: for _ in range(100) ] table = db.create_table( - "test", data=pd.DataFrame({"vector": vectors, "text": text, "text2": text}) + "test", + data=pd.DataFrame( + { + "vector": vectors, + "text": text, + "text2": text, + "nested": [{"text": t} for t in text], + } + ), ) return table @@ -89,3 +97,9 @@ def test_empty_rs(tmp_path, table, mocker): mocker.patch("lancedb.fts.search_index", return_value=([], [])) df = table.search("puppy").limit(10).to_pandas() assert len(df) == 0 + + +def test_nested_schema(tmp_path, table): + table.create_fts_index("nested.text") + rs = table.search("puppy").limit(10).to_list() + assert len(rs) == 10