mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-22 06:20:39 +00:00
fix: discover nested vector columns by default (#3423)
LanceDB default vector column discovery only considered top-level fields, so tables with a single nested vector leaf still required users to pass an explicit field path. This updates Rust and Python discovery to recurse into struct fields, return canonical field paths, and preserve actionable errors when no default or multiple defaults exist. The explicit nested path flow for index creation and search remains supported across Rust, Python, and Node, with regression coverage for single nested vector leaves, multiple candidate leaves, and schemas without vector leaves. Closes #3405.
This commit is contained in:
@@ -10,7 +10,7 @@ import pathlib
|
||||
import warnings
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union, Optional, Any
|
||||
from typing import Tuple, Union, Optional, Any, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
@@ -189,7 +189,33 @@ def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
|
||||
return tbl
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
def _format_field_path(path: List[str]) -> str:
|
||||
def format_segment(segment: str) -> str:
|
||||
if all(char.isalnum() or char == "_" for char in segment):
|
||||
return segment
|
||||
return f"`{segment.replace('`', '``')}`"
|
||||
|
||||
return ".".join(format_segment(segment) for segment in path)
|
||||
|
||||
|
||||
def _iter_vector_columns(
|
||||
field: pa.Field, path: List[str], dim: Optional[int] = None
|
||||
) -> List[str]:
|
||||
field_path = [*path, field.name]
|
||||
if is_vector_column(field.type):
|
||||
vector_dim = infer_vector_column_dim(field.type)
|
||||
if dim is None or vector_dim == dim:
|
||||
return [_format_field_path(field_path)]
|
||||
return []
|
||||
if pa.types.is_struct(field.type):
|
||||
columns = []
|
||||
for idx in range(field.type.num_fields):
|
||||
columns.extend(_iter_vector_columns(field.type.field(idx), field_path, dim))
|
||||
return columns
|
||||
return []
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -202,26 +228,21 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
-------
|
||||
str: the vector column name.
|
||||
"""
|
||||
vector_col_name = ""
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if is_vector_column(field.type):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
"for vector search"
|
||||
)
|
||||
elif vector_col_count == 1:
|
||||
vector_col_name = field_name
|
||||
if vector_col_count == 0:
|
||||
vector_col_names = []
|
||||
for field in schema:
|
||||
vector_col_names.extend(_iter_vector_columns(field, [], dim))
|
||||
if len(vector_col_names) > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
f"for vector search. Candidates: {vector_col_names}"
|
||||
)
|
||||
if len(vector_col_names) == 0:
|
||||
raise ValueError(
|
||||
"There is no vector column in the data. "
|
||||
"Please specify the vector column name for vector search"
|
||||
)
|
||||
return vector_col_name
|
||||
return vector_col_names[0]
|
||||
|
||||
|
||||
def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
@@ -247,6 +268,29 @@ def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]:
|
||||
if pa.types.is_fixed_size_list(data_type):
|
||||
return data_type.list_size
|
||||
if pa.types.is_list(data_type):
|
||||
return infer_vector_column_dim(data_type.value_type)
|
||||
return None
|
||||
|
||||
|
||||
def _query_vector_dim(query: Optional[Any]) -> Optional[int]:
|
||||
if query is None:
|
||||
return None
|
||||
if isinstance(query, np.ndarray):
|
||||
if query.ndim == 0:
|
||||
return None
|
||||
return query.shape[-1]
|
||||
if isinstance(query, list) and query:
|
||||
first = query[0]
|
||||
if isinstance(first, (list, tuple, np.ndarray)):
|
||||
return len(first)
|
||||
return len(query)
|
||||
return None
|
||||
|
||||
|
||||
def infer_vector_column_name(
|
||||
schema: pa.Schema,
|
||||
query_type: str,
|
||||
@@ -262,7 +306,9 @@ def infer_vector_column_name(
|
||||
|
||||
if query is not None or query_type == "hybrid":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(schema)
|
||||
vector_column_name = inf_vector_column_query(
|
||||
schema, dim=_query_vector_dim(query)
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1934,6 +1934,10 @@ def test_create_index_nested_field_paths(mem_db: DBConnection):
|
||||
assert len(vector_results) == 1
|
||||
assert vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
default_vector_results = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert len(default_vector_results) == 1
|
||||
assert default_vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list()
|
||||
assert len(filtered_results) == 1
|
||||
assert filtered_results[0]["metadata"]["user_id"] == 42
|
||||
@@ -2013,6 +2017,74 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
table.search(q).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_infers_single_nested_vector(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{"id": 0, "image": {"embedding": [0.0, 1.0]}},
|
||||
{"id": 1, "image": {"embedding": [10.0, 11.0]}},
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_default_search", data=data)
|
||||
|
||||
result = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert result[0]["id"] == 0
|
||||
|
||||
|
||||
def test_search_nested_vector_multiple_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
pa.field(
|
||||
"text",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{
|
||||
"image": {"embedding": [0.0, 1.0]},
|
||||
"text": {"embedding": [2.0, 3.0]},
|
||||
}
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_multiple_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="image.embedding.*text.embedding"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_nested_vector_no_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("metadata", pa.struct([pa.field("label", pa.string())])),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[{"id": 0, "metadata": {"label": "cat"}}],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_no_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="no vector column"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_compact_cleanup(tmp_db: DBConnection):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
|
||||
Reference in New Issue
Block a user