diff --git a/python/python/tests/docs/test_multivector.py b/python/python/tests/docs/test_multivector.py index d6b6c3a8..74b62a95 100644 --- a/python/python/tests/docs/test_multivector.py +++ b/python/python/tests/docs/test_multivector.py @@ -1,4 +1,5 @@ import shutil +from lancedb.index import IvfPq import pytest # --8<-- [start:imports] @@ -28,6 +29,9 @@ def test_multivector(): ] tbl = db.create_table("my_table", data=data, schema=schema) + # only cosine similarity is supported for multi-vectors + tbl.create_index(metric="cosine") + # query with single vector query = np.random.random(256) tbl.search(query).to_arrow() @@ -59,6 +63,9 @@ async def test_multivector_async(): ] tbl = await db.create_table("my_table", data=data, schema=schema) + # only cosine similarity is supported for multi-vectors + await tbl.create_index(column="vector", config=IvfPq(distance_type="cosine")) + # query with single vector query = np.random.random(256) await tbl.query().nearest_to(query).to_arrow() diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a972bd6e..4022d588 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -73,7 +73,7 @@ use crate::query::{ IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K, }; use crate::utils::{ - default_vector_column, supported_bitmap_data_type, supported_btree_data_type, + default_vector_column, infer_vector_dim, supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type, PatchReadParam, PatchWriteParam, }; @@ -1370,14 +1370,8 @@ impl NativeTable { let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors { n } else { - match field.data_type() { - arrow_schema::DataType::FixedSizeList(_, n) => { - Ok::(suggested_num_sub_vectors(*n as u32)) - } - _ => Err(Error::Schema { - message: format!("Column '{}' is not a FixedSizeList", field.name()), - }), - }? + let dim = infer_vector_dim(field.data_type())?; + suggested_num_sub_vectors(dim as u32) }; let mut dataset = self.dataset.get_mut().await?; let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq( diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index c8455b4c..09ece491 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -188,6 +188,24 @@ pub fn supported_vector_data_type(dtype: &DataType) -> bool { } } +// TODO: remove this after we expose the same function in Lance. +pub fn infer_vector_dim(data_type: &DataType) -> Result { + infer_vector_dim_impl(data_type, false) +} + +fn infer_vector_dim_impl(data_type: &DataType, in_list: bool) -> Result { + match (data_type, in_list) { + (DataType::FixedSizeList(_, dim), _) => Ok(*dim as usize), + (DataType::List(inner), false) => infer_vector_dim_impl(inner.data_type(), true), + _ => Err(Error::InvalidInput { + message: format!( + "data type is not a vector (FixedSizeList or List), but {:?}", + data_type + ), + }), + } +} + /// Note: this is temporary until we get a proper datatype conversion in Lance. pub fn string_to_datatype(s: &str) -> Option { let data_type = serde_json::Value::String(s.to_string());