mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-21 22:10:40 +00:00
fix: discover nested vector columns by default (#3423)
LanceDB default vector column discovery only considered top-level fields, so tables with a single nested vector leaf still required users to pass an explicit field path. This updates Rust and Python discovery to recurse into struct fields, return canonical field paths, and preserve actionable errors when no default or multiple defaults exist. The explicit nested path flow for index creation and search remains supported across Rust, Python, and Node, with regression coverage for single nested vector leaves, multiple candidate leaves, and schemas without vector leaves. Closes #3405.
This commit is contained in:
@@ -28,6 +28,7 @@ import {
|
||||
List,
|
||||
Schema,
|
||||
SchemaLike,
|
||||
Struct,
|
||||
Type,
|
||||
Uint8,
|
||||
Utf8,
|
||||
@@ -780,6 +781,113 @@ describe("When creating an index", () => {
|
||||
expect(indices2.length).toBe(0);
|
||||
});
|
||||
|
||||
it("should create and search a nested vector index", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const nestedSchema = new Schema([
|
||||
new Field("id", new Int32(), true),
|
||||
new Field(
|
||||
"image",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const nestedTable = await db.createTable(
|
||||
"nested_vector",
|
||||
makeArrowTable(
|
||||
Array.from({ length: 300 }, (_, id) => ({
|
||||
id,
|
||||
image: { embedding: [id, id + 1] },
|
||||
})),
|
||||
{ schema: nestedSchema },
|
||||
),
|
||||
);
|
||||
|
||||
await nestedTable.createIndex("image.embedding", {
|
||||
name: "image_embedding_idx",
|
||||
});
|
||||
const indices = await nestedTable.listIndices();
|
||||
expect(indices).toContainEqual({
|
||||
name: "image_embedding_idx",
|
||||
indexType: "IvfPq",
|
||||
columns: ["image.embedding"],
|
||||
});
|
||||
|
||||
const explicit = await nestedTable
|
||||
.query()
|
||||
.nearestTo([0.0, 1.0])
|
||||
.column("image.embedding")
|
||||
.limit(1)
|
||||
.toArray();
|
||||
const inferred = await nestedTable
|
||||
.query()
|
||||
.nearestTo([0.0, 1.0])
|
||||
.limit(1)
|
||||
.toArray();
|
||||
expect(inferred[0].id).toEqual(explicit[0].id);
|
||||
});
|
||||
|
||||
it("should report multiple nested vector candidates", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const nestedSchema = new Schema([
|
||||
new Field(
|
||||
"image",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
new Field(
|
||||
"text",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const nestedTable = await db.createTable(
|
||||
"multiple_nested_vectors",
|
||||
makeArrowTable(
|
||||
[
|
||||
{
|
||||
image: { embedding: [0.0, 1.0] },
|
||||
text: { embedding: [2.0, 3.0] },
|
||||
},
|
||||
],
|
||||
{ schema: nestedSchema },
|
||||
),
|
||||
);
|
||||
|
||||
await expect(
|
||||
nestedTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
|
||||
).rejects.toThrow(/image\.embedding.*text\.embedding/);
|
||||
});
|
||||
|
||||
it("should report when no default vector column exists", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const noVectorTable = await db.createTable(
|
||||
"no_vector",
|
||||
makeArrowTable([{ id: 0, label: "cat" }]),
|
||||
);
|
||||
|
||||
await expect(
|
||||
noVectorTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
|
||||
).rejects.toThrow(/No vector column/);
|
||||
});
|
||||
|
||||
it("should wait for index readiness", async () => {
|
||||
// Create an index and then wait for it to be ready
|
||||
await tbl.createIndex("vec");
|
||||
|
||||
@@ -10,7 +10,7 @@ import pathlib
|
||||
import warnings
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union, Optional, Any
|
||||
from typing import Tuple, Union, Optional, Any, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
@@ -189,7 +189,33 @@ def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
|
||||
return tbl
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
def _format_field_path(path: List[str]) -> str:
|
||||
def format_segment(segment: str) -> str:
|
||||
if all(char.isalnum() or char == "_" for char in segment):
|
||||
return segment
|
||||
return f"`{segment.replace('`', '``')}`"
|
||||
|
||||
return ".".join(format_segment(segment) for segment in path)
|
||||
|
||||
|
||||
def _iter_vector_columns(
|
||||
field: pa.Field, path: List[str], dim: Optional[int] = None
|
||||
) -> List[str]:
|
||||
field_path = [*path, field.name]
|
||||
if is_vector_column(field.type):
|
||||
vector_dim = infer_vector_column_dim(field.type)
|
||||
if dim is None or vector_dim == dim:
|
||||
return [_format_field_path(field_path)]
|
||||
return []
|
||||
if pa.types.is_struct(field.type):
|
||||
columns = []
|
||||
for idx in range(field.type.num_fields):
|
||||
columns.extend(_iter_vector_columns(field.type.field(idx), field_path, dim))
|
||||
return columns
|
||||
return []
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -202,26 +228,21 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
-------
|
||||
str: the vector column name.
|
||||
"""
|
||||
vector_col_name = ""
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if is_vector_column(field.type):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
"for vector search"
|
||||
)
|
||||
elif vector_col_count == 1:
|
||||
vector_col_name = field_name
|
||||
if vector_col_count == 0:
|
||||
vector_col_names = []
|
||||
for field in schema:
|
||||
vector_col_names.extend(_iter_vector_columns(field, [], dim))
|
||||
if len(vector_col_names) > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
f"for vector search. Candidates: {vector_col_names}"
|
||||
)
|
||||
if len(vector_col_names) == 0:
|
||||
raise ValueError(
|
||||
"There is no vector column in the data. "
|
||||
"Please specify the vector column name for vector search"
|
||||
)
|
||||
return vector_col_name
|
||||
return vector_col_names[0]
|
||||
|
||||
|
||||
def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
@@ -247,6 +268,29 @@ def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]:
|
||||
if pa.types.is_fixed_size_list(data_type):
|
||||
return data_type.list_size
|
||||
if pa.types.is_list(data_type):
|
||||
return infer_vector_column_dim(data_type.value_type)
|
||||
return None
|
||||
|
||||
|
||||
def _query_vector_dim(query: Optional[Any]) -> Optional[int]:
|
||||
if query is None:
|
||||
return None
|
||||
if isinstance(query, np.ndarray):
|
||||
if query.ndim == 0:
|
||||
return None
|
||||
return query.shape[-1]
|
||||
if isinstance(query, list) and query:
|
||||
first = query[0]
|
||||
if isinstance(first, (list, tuple, np.ndarray)):
|
||||
return len(first)
|
||||
return len(query)
|
||||
return None
|
||||
|
||||
|
||||
def infer_vector_column_name(
|
||||
schema: pa.Schema,
|
||||
query_type: str,
|
||||
@@ -262,7 +306,9 @@ def infer_vector_column_name(
|
||||
|
||||
if query is not None or query_type == "hybrid":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(schema)
|
||||
vector_column_name = inf_vector_column_query(
|
||||
schema, dim=_query_vector_dim(query)
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1934,6 +1934,10 @@ def test_create_index_nested_field_paths(mem_db: DBConnection):
|
||||
assert len(vector_results) == 1
|
||||
assert vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
default_vector_results = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert len(default_vector_results) == 1
|
||||
assert default_vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list()
|
||||
assert len(filtered_results) == 1
|
||||
assert filtered_results[0]["metadata"]["user_id"] == 42
|
||||
@@ -2013,6 +2017,74 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
table.search(q).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_infers_single_nested_vector(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{"id": 0, "image": {"embedding": [0.0, 1.0]}},
|
||||
{"id": 1, "image": {"embedding": [10.0, 11.0]}},
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_default_search", data=data)
|
||||
|
||||
result = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert result[0]["id"] == 0
|
||||
|
||||
|
||||
def test_search_nested_vector_multiple_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
pa.field(
|
||||
"text",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{
|
||||
"image": {"embedding": [0.0, 1.0]},
|
||||
"text": {"embedding": [2.0, 3.0]},
|
||||
}
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_multiple_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="image.embedding.*text.embedding"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_nested_vector_no_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("metadata", pa.struct([pa.field("label", pa.string())])),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[{"id": 0, "metadata": {"label": "cat"}}],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_no_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="no vector column"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_compact_cleanup(tmp_db: DBConnection):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
|
||||
@@ -27,7 +27,9 @@ use crate::table::UpdateResult;
|
||||
use crate::table::query::create_multi_vector_plan;
|
||||
use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics};
|
||||
use crate::utils::background_cache::BackgroundCache;
|
||||
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
|
||||
use crate::utils::{
|
||||
resolve_arrow_field_path, supported_btree_data_type, supported_vector_data_type,
|
||||
};
|
||||
use crate::{DistanceType, Error};
|
||||
use crate::{
|
||||
error::Result,
|
||||
@@ -1563,11 +1565,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
Index::FTS(p) => ("FTS", Some(to_json(p)?)),
|
||||
Index::Auto => {
|
||||
let schema = self.schema().await?;
|
||||
let field = schema
|
||||
.field_with_name(&column)
|
||||
.map_err(|_| Error::InvalidInput {
|
||||
message: format!("Column {} not found in schema", column),
|
||||
})?;
|
||||
let field = resolve_arrow_field_path(&schema, &column)?;
|
||||
if supported_vector_data_type(field.data_type()) {
|
||||
body[METRIC_TYPE_KEY] =
|
||||
serde_json::Value::String(DistanceType::L2.to_string().to_lowercase());
|
||||
|
||||
@@ -3877,6 +3877,25 @@ mod tests {
|
||||
1
|
||||
);
|
||||
|
||||
let default_vector_results = table
|
||||
.query()
|
||||
.nearest_to(&[0.0; 8])
|
||||
.unwrap()
|
||||
.limit(1)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
default_vector_results
|
||||
.iter()
|
||||
.map(|batch| batch.num_rows())
|
||||
.sum::<usize>(),
|
||||
1
|
||||
);
|
||||
|
||||
let fts_results = table
|
||||
.query()
|
||||
.full_text_search(FullTextSearchQuery::new("document".to_string()))
|
||||
|
||||
@@ -6,7 +6,7 @@ pub(crate) mod background_cache;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use arrow_schema::{DataType, Schema, SchemaRef};
|
||||
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use datafusion_execution::RecordBatchStream;
|
||||
use futures::{FutureExt, Stream};
|
||||
@@ -152,14 +152,10 @@ pub fn validate_namespace(namespace: &[String]) -> Result<()> {
|
||||
/// Find one default column to create index or perform vector query.
|
||||
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
|
||||
// Try to find a vector column.
|
||||
let candidates = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.filter_map(|field| match infer_vector_dim(field.data_type()) {
|
||||
Ok(d) if dim.is_none() || dim == Some(d as i32) => Some(field.name()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut candidates = Vec::new();
|
||||
for field in schema.fields() {
|
||||
collect_vector_columns(field, &mut Vec::new(), dim, &mut candidates);
|
||||
}
|
||||
if candidates.is_empty() {
|
||||
Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
@@ -180,6 +176,63 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_vector_columns(
|
||||
field: &Field,
|
||||
path: &mut Vec<String>,
|
||||
dim: Option<i32>,
|
||||
candidates: &mut Vec<String>,
|
||||
) {
|
||||
path.push(field.name().clone());
|
||||
match infer_vector_dim(field.data_type()) {
|
||||
Ok(d) if dim.is_none() || dim == Some(d as i32) => {
|
||||
let path_segments = path.iter().map(String::as_str).collect::<Vec<_>>();
|
||||
candidates.push(lance_core::datatypes::format_field_path(&path_segments));
|
||||
}
|
||||
_ => {
|
||||
if let DataType::Struct(fields) = field.data_type() {
|
||||
for child in fields {
|
||||
collect_vector_columns(child, path, dim, candidates);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
path.pop();
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_arrow_field_path(schema: &Schema, column: &str) -> Result<Field> {
|
||||
let segments =
|
||||
lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput {
|
||||
message: format!("Invalid field path `{}`: {}", column, e),
|
||||
})?;
|
||||
let mut fields = schema.fields();
|
||||
|
||||
for (idx, segment) in segments.iter().enumerate() {
|
||||
let field = find_field(fields, segment).ok_or_else(|| Error::Schema {
|
||||
message: format!("Field path `{}` not found in schema", column),
|
||||
})?;
|
||||
if idx + 1 == segments.len() {
|
||||
return Ok(field.clone());
|
||||
}
|
||||
fields = match field.data_type() {
|
||||
DataType::Struct(fields) => fields,
|
||||
_ => {
|
||||
return Err(Error::Schema {
|
||||
message: format!("Field path `{}` not found in schema", column),
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
unreachable!("parse_field_path returns at least one segment")
|
||||
}
|
||||
|
||||
fn find_field<'a>(fields: &'a Fields, name: &str) -> Option<&'a Field> {
|
||||
fields
|
||||
.iter()
|
||||
.find(|field| field.name() == name)
|
||||
.map(|field| field.as_ref())
|
||||
}
|
||||
|
||||
pub fn supported_btree_data_type(dtype: &DataType) -> bool {
|
||||
dtype.is_integer()
|
||||
|| dtype.is_floating()
|
||||
@@ -450,6 +503,49 @@ mod tests {
|
||||
"vec"
|
||||
);
|
||||
|
||||
let schema_with_nested_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
"image",
|
||||
DataType::Struct(
|
||||
vec![Field::new(
|
||||
"embedding",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, false)),
|
||||
10,
|
||||
),
|
||||
false,
|
||||
)]
|
||||
.into(),
|
||||
),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert_eq!(
|
||||
default_vector_column(&schema_with_nested_vec_col, None).unwrap(),
|
||||
"image.embedding"
|
||||
);
|
||||
|
||||
let schema_with_escaped_nested_vec_col = Schema::new(vec![Field::new(
|
||||
"image-meta",
|
||||
DataType::Struct(
|
||||
vec![Field::new(
|
||||
"embedding.v1",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, false)),
|
||||
10,
|
||||
),
|
||||
false,
|
||||
)]
|
||||
.into(),
|
||||
),
|
||||
false,
|
||||
)]);
|
||||
assert_eq!(
|
||||
default_vector_column(&schema_with_escaped_nested_vec_col, None).unwrap(),
|
||||
"`image-meta`.`embedding.v1`"
|
||||
);
|
||||
|
||||
let multi_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
@@ -469,6 +565,48 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("More than one")
|
||||
);
|
||||
|
||||
let multi_nested_vec_col = Schema::new(vec![
|
||||
Field::new(
|
||||
"image",
|
||||
DataType::Struct(
|
||||
vec![Field::new(
|
||||
"embedding",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, false)),
|
||||
10,
|
||||
),
|
||||
false,
|
||||
)]
|
||||
.into(),
|
||||
),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"text",
|
||||
DataType::Struct(
|
||||
vec![Field::new(
|
||||
"embedding",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, false)),
|
||||
50,
|
||||
),
|
||||
false,
|
||||
)]
|
||||
.into(),
|
||||
),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert_eq!(
|
||||
default_vector_column(&multi_nested_vec_col, Some(50)).unwrap(),
|
||||
"text.embedding"
|
||||
);
|
||||
let err = default_vector_column(&multi_nested_vec_col, None)
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
assert!(err.contains("image.embedding"));
|
||||
assert!(err.contains("text.embedding"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user