diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 1d8cce32..13f0d458 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -41,6 +41,7 @@ use lance::dataset::{ WriteParams, }; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; +use lance::index::vector::utils::infer_vector_dim; use lance::io::WrappingObjectStore; use lance_datafusion::exec::execute_plan; use lance_index::vector::hnsw::builder::HnswBuildParams; @@ -73,7 +74,7 @@ use crate::query::{ IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K, }; use crate::utils::{ - default_vector_column, infer_vector_dim, supported_bitmap_data_type, supported_btree_data_type, + default_vector_column, supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type, PatchReadParam, PatchWriteParam, }; diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 09ece491..d630ed7c 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use arrow_schema::{DataType, Schema}; use lance::arrow::json::JsonDataType; use lance::dataset::{ReadParams, WriteParams}; +use lance::index::vector::utils::infer_vector_dim; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lazy_static::lazy_static; @@ -104,12 +105,12 @@ pub fn validate_table_name(name: &str) -> Result<()> { /// Find one default column to create index or perform vector query. pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result { - // Try to find one fixed size list array column. + // Try to find a vector column. let candidates = schema .fields() .iter() - .filter_map(|field| match inf_vector_dim(field) { - Some(d) if dim.is_none() || dim == Some(d) => Some(field.name()), + .filter_map(|field| match infer_vector_dim(field.data_type()) { + Ok(d) if dim.is_none() || dim == Some(d as i32) => Some(field.name()), _ => None, }) .collect::>(); @@ -133,20 +134,6 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result } } -fn inf_vector_dim(field: &arrow_schema::Field) -> Option { - match field.data_type() { - arrow_schema::DataType::FixedSizeList(f, d) => { - if f.data_type().is_floating() || f.data_type() == &DataType::UInt8 { - Some(*d) - } else { - None - } - } - arrow_schema::DataType::List(f) => inf_vector_dim(f), - _ => None, - } -} - pub fn supported_btree_data_type(dtype: &DataType) -> bool { dtype.is_integer() || dtype.is_floating() @@ -188,24 +175,6 @@ pub fn supported_vector_data_type(dtype: &DataType) -> bool { } } -// TODO: remove this after we expose the same function in Lance. -pub fn infer_vector_dim(data_type: &DataType) -> Result { - infer_vector_dim_impl(data_type, false) -} - -fn infer_vector_dim_impl(data_type: &DataType, in_list: bool) -> Result { - match (data_type, in_list) { - (DataType::FixedSizeList(_, dim), _) => Ok(*dim as usize), - (DataType::List(inner), false) => infer_vector_dim_impl(inner.data_type(), true), - _ => Err(Error::InvalidInput { - message: format!( - "data type is not a vector (FixedSizeList or List), but {:?}", - data_type - ), - }), - } -} - /// Note: this is temporary until we get a proper datatype conversion in Lance. pub fn string_to_datatype(s: &str) -> Option { let data_type = serde_json::Value::String(s.to_string());