diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index fda877a45..7d43ca351 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -28,6 +28,7 @@ import { List, Schema, SchemaLike, + Struct, Type, Uint8, Utf8, @@ -780,6 +781,113 @@ describe("When creating an index", () => { expect(indices2.length).toBe(0); }); + it("should create and search a nested vector index", async () => { + const db = await connect(tmpDir.name); + const nestedSchema = new Schema([ + new Field("id", new Int32(), true), + new Field( + "image", + new Struct([ + new Field( + "embedding", + new FixedSizeList(2, new Field("item", new Float32(), true)), + true, + ), + ]), + true, + ), + ]); + const nestedTable = await db.createTable( + "nested_vector", + makeArrowTable( + Array.from({ length: 300 }, (_, id) => ({ + id, + image: { embedding: [id, id + 1] }, + })), + { schema: nestedSchema }, + ), + ); + + await nestedTable.createIndex("image.embedding", { + name: "image_embedding_idx", + }); + const indices = await nestedTable.listIndices(); + expect(indices).toContainEqual({ + name: "image_embedding_idx", + indexType: "IvfPq", + columns: ["image.embedding"], + }); + + const explicit = await nestedTable + .query() + .nearestTo([0.0, 1.0]) + .column("image.embedding") + .limit(1) + .toArray(); + const inferred = await nestedTable + .query() + .nearestTo([0.0, 1.0]) + .limit(1) + .toArray(); + expect(inferred[0].id).toEqual(explicit[0].id); + }); + + it("should report multiple nested vector candidates", async () => { + const db = await connect(tmpDir.name); + const nestedSchema = new Schema([ + new Field( + "image", + new Struct([ + new Field( + "embedding", + new FixedSizeList(2, new Field("item", new Float32(), true)), + true, + ), + ]), + true, + ), + new Field( + "text", + new Struct([ + new Field( + "embedding", + new FixedSizeList(2, new Field("item", new Float32(), true)), + true, + ), + ]), + true, + ), + ]); + const nestedTable = await db.createTable( + "multiple_nested_vectors", + makeArrowTable( + [ + { + image: { embedding: [0.0, 1.0] }, + text: { embedding: [2.0, 3.0] }, + }, + ], + { schema: nestedSchema }, + ), + ); + + await expect( + nestedTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(), + ).rejects.toThrow(/image\.embedding.*text\.embedding/); + }); + + it("should report when no default vector column exists", async () => { + const db = await connect(tmpDir.name); + const noVectorTable = await db.createTable( + "no_vector", + makeArrowTable([{ id: 0, label: "cat" }]), + ); + + await expect( + noVectorTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(), + ).rejects.toThrow(/No vector column/); + }); + it("should wait for index readiness", async () => { // Create an index and then wait for it to be ready await tbl.createIndex("vec"); diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index d5b66707f..b8e519933 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -10,7 +10,7 @@ import pathlib import warnings from datetime import date, datetime from functools import singledispatch -from typing import Tuple, Union, Optional, Any +from typing import Tuple, Union, Optional, Any, List from urllib.parse import urlparse import numpy as np @@ -189,7 +189,33 @@ def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None): return tbl -def inf_vector_column_query(schema: pa.Schema) -> str: +def _format_field_path(path: List[str]) -> str: + def format_segment(segment: str) -> str: + if all(char.isalnum() or char == "_" for char in segment): + return segment + return f"`{segment.replace('`', '``')}`" + + return ".".join(format_segment(segment) for segment in path) + + +def _iter_vector_columns( + field: pa.Field, path: List[str], dim: Optional[int] = None +) -> List[str]: + field_path = [*path, field.name] + if is_vector_column(field.type): + vector_dim = infer_vector_column_dim(field.type) + if dim is None or vector_dim == dim: + return [_format_field_path(field_path)] + return [] + if pa.types.is_struct(field.type): + columns = [] + for idx in range(field.type.num_fields): + columns.extend(_iter_vector_columns(field.type.field(idx), field_path, dim)) + return columns + return [] + + +def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str: """ Get the vector column name @@ -202,26 +228,21 @@ def inf_vector_column_query(schema: pa.Schema) -> str: ------- str: the vector column name. """ - vector_col_name = "" - vector_col_count = 0 - for field_name in schema.names: - field = schema.field(field_name) - if is_vector_column(field.type): - vector_col_count += 1 - if vector_col_count > 1: - raise ValueError( - "Schema has more than one vector column. " - "Please specify the vector column name " - "for vector search" - ) - elif vector_col_count == 1: - vector_col_name = field_name - if vector_col_count == 0: + vector_col_names = [] + for field in schema: + vector_col_names.extend(_iter_vector_columns(field, [], dim)) + if len(vector_col_names) > 1: + raise ValueError( + "Schema has more than one vector column. " + "Please specify the vector column name " + f"for vector search. Candidates: {vector_col_names}" + ) + if len(vector_col_names) == 0: raise ValueError( "There is no vector column in the data. " "Please specify the vector column name for vector search" ) - return vector_col_name + return vector_col_names[0] def is_vector_column(data_type: pa.DataType) -> bool: @@ -247,6 +268,29 @@ def is_vector_column(data_type: pa.DataType) -> bool: return False +def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]: + if pa.types.is_fixed_size_list(data_type): + return data_type.list_size + if pa.types.is_list(data_type): + return infer_vector_column_dim(data_type.value_type) + return None + + +def _query_vector_dim(query: Optional[Any]) -> Optional[int]: + if query is None: + return None + if isinstance(query, np.ndarray): + if query.ndim == 0: + return None + return query.shape[-1] + if isinstance(query, list) and query: + first = query[0] + if isinstance(first, (list, tuple, np.ndarray)): + return len(first) + return len(query) + return None + + def infer_vector_column_name( schema: pa.Schema, query_type: str, @@ -262,7 +306,9 @@ def infer_vector_column_name( if query is not None or query_type == "hybrid": try: - vector_column_name = inf_vector_column_query(schema) + vector_column_name = inf_vector_column_query( + schema, dim=_query_vector_dim(query) + ) except Exception as e: raise e diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index bf1c11fe1..ed4656d81 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1934,6 +1934,10 @@ def test_create_index_nested_field_paths(mem_db: DBConnection): assert len(vector_results) == 1 assert vector_results[0]["metadata"]["user_id"] == 0 + default_vector_results = table.search([0.0, 1.0]).limit(1).to_list() + assert len(default_vector_results) == 1 + assert default_vector_results[0]["metadata"]["user_id"] == 0 + filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list() assert len(filtered_results) == 1 assert filtered_results[0]["metadata"]["user_id"] == 42 @@ -2013,6 +2017,74 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection): table.search(q).limit(1).to_arrow() +def test_search_infers_single_nested_vector(mem_db: DBConnection): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field( + "image", + pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), + ), + ] + ) + data = pa.Table.from_pylist( + [ + {"id": 0, "image": {"embedding": [0.0, 1.0]}}, + {"id": 1, "image": {"embedding": [10.0, 11.0]}}, + ], + schema=schema, + ) + table = mem_db.create_table("nested_vector_default_search", data=data) + + result = table.search([0.0, 1.0]).limit(1).to_list() + assert result[0]["id"] == 0 + + +def test_search_nested_vector_multiple_candidates(mem_db: DBConnection): + schema = pa.schema( + [ + pa.field( + "image", + pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), + ), + pa.field( + "text", + pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]), + ), + ] + ) + data = pa.Table.from_pylist( + [ + { + "image": {"embedding": [0.0, 1.0]}, + "text": {"embedding": [2.0, 3.0]}, + } + ], + schema=schema, + ) + table = mem_db.create_table("nested_vector_multiple_candidates", data=data) + + with pytest.raises(ValueError, match="image.embedding.*text.embedding"): + table.search([0.0, 1.0]).limit(1).to_arrow() + + +def test_search_nested_vector_no_candidates(mem_db: DBConnection): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("metadata", pa.struct([pa.field("label", pa.string())])), + ] + ) + data = pa.Table.from_pylist( + [{"id": 0, "metadata": {"label": "cat"}}], + schema=schema, + ) + table = mem_db.create_table("nested_vector_no_candidates", data=data) + + with pytest.raises(ValueError, match="no vector column"): + table.search([0.0, 1.0]).limit(1).to_arrow() + + def test_compact_cleanup(tmp_db: DBConnection): pytest.importorskip("lance") table = tmp_db.create_table( diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 1bad181d9..efc23415e 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -27,7 +27,9 @@ use crate::table::UpdateResult; use crate::table::query::create_multi_vector_plan; use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics}; use crate::utils::background_cache::BackgroundCache; -use crate::utils::{supported_btree_data_type, supported_vector_data_type}; +use crate::utils::{ + resolve_arrow_field_path, supported_btree_data_type, supported_vector_data_type, +}; use crate::{DistanceType, Error}; use crate::{ error::Result, @@ -1563,11 +1565,7 @@ impl BaseTable for RemoteTable { Index::FTS(p) => ("FTS", Some(to_json(p)?)), Index::Auto => { let schema = self.schema().await?; - let field = schema - .field_with_name(&column) - .map_err(|_| Error::InvalidInput { - message: format!("Column {} not found in schema", column), - })?; + let field = resolve_arrow_field_path(&schema, &column)?; if supported_vector_data_type(field.data_type()) { body[METRIC_TYPE_KEY] = serde_json::Value::String(DistanceType::L2.to_string().to_lowercase()); diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 7a2417822..03f967e6e 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -3877,6 +3877,25 @@ mod tests { 1 ); + let default_vector_results = table + .query() + .nearest_to(&[0.0; 8]) + .unwrap() + .limit(1) + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + assert_eq!( + default_vector_results + .iter() + .map(|batch| batch.num_rows()) + .sum::(), + 1 + ); + let fts_results = table .query() .full_text_search(FullTextSearchQuery::new("document".to_string())) diff --git a/rust/lancedb/src/utils/mod.rs b/rust/lancedb/src/utils/mod.rs index 0af8623b4..d43912058 100644 --- a/rust/lancedb/src/utils/mod.rs +++ b/rust/lancedb/src/utils/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod background_cache; use std::sync::Arc; use arrow_array::RecordBatch; -use arrow_schema::{DataType, Schema, SchemaRef}; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_execution::RecordBatchStream; use futures::{FutureExt, Stream}; @@ -152,14 +152,10 @@ pub fn validate_namespace(namespace: &[String]) -> Result<()> { /// Find one default column to create index or perform vector query. pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result { // Try to find a vector column. - let candidates = schema - .fields() - .iter() - .filter_map(|field| match infer_vector_dim(field.data_type()) { - Ok(d) if dim.is_none() || dim == Some(d as i32) => Some(field.name()), - _ => None, - }) - .collect::>(); + let mut candidates = Vec::new(); + for field in schema.fields() { + collect_vector_columns(field, &mut Vec::new(), dim, &mut candidates); + } if candidates.is_empty() { Err(Error::InvalidInput { message: format!( @@ -180,6 +176,63 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result } } +fn collect_vector_columns( + field: &Field, + path: &mut Vec, + dim: Option, + candidates: &mut Vec, +) { + path.push(field.name().clone()); + match infer_vector_dim(field.data_type()) { + Ok(d) if dim.is_none() || dim == Some(d as i32) => { + let path_segments = path.iter().map(String::as_str).collect::>(); + candidates.push(lance_core::datatypes::format_field_path(&path_segments)); + } + _ => { + if let DataType::Struct(fields) = field.data_type() { + for child in fields { + collect_vector_columns(child, path, dim, candidates); + } + } + } + } + path.pop(); +} + +pub(crate) fn resolve_arrow_field_path(schema: &Schema, column: &str) -> Result { + let segments = + lance_core::datatypes::parse_field_path(column).map_err(|e| Error::InvalidInput { + message: format!("Invalid field path `{}`: {}", column, e), + })?; + let mut fields = schema.fields(); + + for (idx, segment) in segments.iter().enumerate() { + let field = find_field(fields, segment).ok_or_else(|| Error::Schema { + message: format!("Field path `{}` not found in schema", column), + })?; + if idx + 1 == segments.len() { + return Ok(field.clone()); + } + fields = match field.data_type() { + DataType::Struct(fields) => fields, + _ => { + return Err(Error::Schema { + message: format!("Field path `{}` not found in schema", column), + }); + } + }; + } + + unreachable!("parse_field_path returns at least one segment") +} + +fn find_field<'a>(fields: &'a Fields, name: &str) -> Option<&'a Field> { + fields + .iter() + .find(|field| field.name() == name) + .map(|field| field.as_ref()) +} + pub fn supported_btree_data_type(dtype: &DataType) -> bool { dtype.is_integer() || dtype.is_floating() @@ -450,6 +503,49 @@ mod tests { "vec" ); + let schema_with_nested_vec_col = Schema::new(vec![ + Field::new("id", DataType::Int16, true), + Field::new( + "image", + DataType::Struct( + vec![Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, false)), + 10, + ), + false, + )] + .into(), + ), + false, + ), + ]); + assert_eq!( + default_vector_column(&schema_with_nested_vec_col, None).unwrap(), + "image.embedding" + ); + + let schema_with_escaped_nested_vec_col = Schema::new(vec![Field::new( + "image-meta", + DataType::Struct( + vec![Field::new( + "embedding.v1", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, false)), + 10, + ), + false, + )] + .into(), + ), + false, + )]); + assert_eq!( + default_vector_column(&schema_with_escaped_nested_vec_col, None).unwrap(), + "`image-meta`.`embedding.v1`" + ); + let multi_vec_col = Schema::new(vec![ Field::new("id", DataType::Int16, true), Field::new( @@ -469,6 +565,48 @@ mod tests { .to_string() .contains("More than one") ); + + let multi_nested_vec_col = Schema::new(vec![ + Field::new( + "image", + DataType::Struct( + vec![Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, false)), + 10, + ), + false, + )] + .into(), + ), + false, + ), + Field::new( + "text", + DataType::Struct( + vec![Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, false)), + 50, + ), + false, + )] + .into(), + ), + false, + ), + ]); + assert_eq!( + default_vector_column(&multi_nested_vec_col, Some(50)).unwrap(), + "text.embedding" + ); + let err = default_vector_column(&multi_nested_vec_col, None) + .unwrap_err() + .to_string(); + assert!(err.contains("image.embedding")); + assert!(err.contains("text.embedding")); } #[test]