mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat(python): support nested reference for fts (#723)
https://github.com/lancedb/lance/issues/1739 Support nested field reference in full text search --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -75,8 +75,14 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
|
||||
The number of rows indexed
|
||||
"""
|
||||
# first check the fields exist and are string or large string type
|
||||
nested = []
|
||||
for name in fields:
|
||||
f = table.schema.field(name) # raises KeyError if not found
|
||||
try:
|
||||
f = table.schema.field(name) # raises KeyError if not found
|
||||
except KeyError:
|
||||
f = resolve_path(table.schema, name)
|
||||
nested.append(name)
|
||||
|
||||
if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type):
|
||||
raise TypeError(f"Field {name} is not a string type")
|
||||
|
||||
@@ -85,7 +91,16 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
|
||||
# write data into index
|
||||
dataset = table.to_lance()
|
||||
row_id = 0
|
||||
|
||||
max_nested_level = 0
|
||||
if len(nested) > 0:
|
||||
max_nested_level = max([len(name.split(".")) for name in nested])
|
||||
|
||||
for b in dataset.to_batches(columns=fields):
|
||||
if max_nested_level > 0:
|
||||
b = pa.Table.from_batches([b])
|
||||
for _ in range(max_nested_level - 1):
|
||||
b = b.flatten()
|
||||
for i in range(b.num_rows):
|
||||
doc = tantivy.Document()
|
||||
doc.add_integer("doc_id", row_id)
|
||||
@@ -98,6 +113,30 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
|
||||
return row_id
|
||||
|
||||
|
||||
def resolve_path(schema, field_name: str) -> pa.Field:
|
||||
"""
|
||||
Resolve a nested field path to a list of field names
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field_name : str
|
||||
The field name to resolve
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
The resolved path
|
||||
"""
|
||||
path = field_name.split(".")
|
||||
field = schema.field(path.pop(0))
|
||||
for segment in path:
|
||||
if pa.types.is_struct(field.type):
|
||||
field = field.type.field(segment)
|
||||
else:
|
||||
raise KeyError(f"field {field_name} not found in schema {schema}")
|
||||
return field
|
||||
|
||||
|
||||
def search_index(
|
||||
index: tantivy.Index, query: str, limit: int = 10
|
||||
) -> Tuple[Tuple[int], Tuple[float]]:
|
||||
|
||||
@@ -43,7 +43,15 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
for _ in range(100)
|
||||
]
|
||||
table = db.create_table(
|
||||
"test", data=pd.DataFrame({"vector": vectors, "text": text, "text2": text})
|
||||
"test",
|
||||
data=pd.DataFrame(
|
||||
{
|
||||
"vector": vectors,
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"nested": [{"text": t} for t in text],
|
||||
}
|
||||
),
|
||||
)
|
||||
return table
|
||||
|
||||
@@ -89,3 +97,9 @@ def test_empty_rs(tmp_path, table, mocker):
|
||||
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
||||
df = table.search("puppy").limit(10).to_pandas()
|
||||
assert len(df) == 0
|
||||
|
||||
|
||||
def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text")
|
||||
rs = table.search("puppy").limit(10).to_list()
|
||||
assert len(rs) == 10
|
||||
|
||||
Reference in New Issue
Block a user