diff --git a/Cargo.toml b/Cargo.toml index a622b608..8905a303 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,23 @@ [workspace] -members = [ - "rust/vectordb", - "rust/ffi/node" -] +members = ["rust/ffi/node", "rust/vectordb"] +# Python package needs to be built by maturin. +exclude = ["python"] resolver = "2" [workspace.dependencies] lance = "=0.6.5" +# Note that this one does not include pyarrow +arrow = { version = "43.0.0", optional = false } arrow-array = "43.0" arrow-data = "43.0" -arrow-schema = "43.0" arrow-ipc = "43.0" -half = { "version" = "=2.2.1", default-features = false } +arrow-ord = "43.0" +arrow-schema = "43.0" +arrow-arith = "43.0" +arrow-cast = "43.0" +half = { "version" = "=2.2.1", default-features = false, features = [ + "num-traits" +] } +log = "0.4" object_store = "0.6.1" snafu = "0.7.4" diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml index 9c56a204..34247aa2 100644 --- a/rust/vectordb/Cargo.toml +++ b/rust/vectordb/Cargo.toml @@ -10,14 +10,19 @@ categories = ["database-implementations"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +arrow = { workspace = true } arrow-array = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } +arrow-ord = { workspace = true } +arrow-cast = { workspace = true } object_store = { workspace = true } snafu = { workspace = true } half = { workspace = true } lance = { workspace = true } tokio = { version = "1.23", features = ["rt-multi-thread"] } +log = { workspace = true } +num-traits = "0" [dev-dependencies] tempfile = "3.5.0" diff --git a/rust/vectordb/src/arrow.rs b/rust/vectordb/src/arrow.rs new file mode 100644 index 00000000..9e4ae419 --- /dev/null +++ b/rust/vectordb/src/arrow.rs @@ -0,0 +1,15 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub use lance::arrow::*; \ No newline at end of file diff --git a/rust/vectordb/src/data.rs b/rust/vectordb/src/data.rs new file mode 100644 index 00000000..ce566063 --- /dev/null +++ b/rust/vectordb/src/data.rs @@ -0,0 +1,18 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Data types, schema coercion, data cleaning, etc. + +pub mod inspect; +pub mod sanitize; diff --git a/rust/vectordb/src/data/inspect.rs b/rust/vectordb/src/data/inspect.rs new file mode 100644 index 00000000..7c1ab849 --- /dev/null +++ b/rust/vectordb/src/data/inspect.rs @@ -0,0 +1,180 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use arrow::compute::kernels::{aggregate::bool_and, length::length}; +use arrow_array::{ + cast::AsArray, + types::{ArrowPrimitiveType, Int32Type, Int64Type}, + Array, GenericListArray, OffsetSizeTrait, RecordBatchReader, +}; +use arrow_ord::comparison::eq_dyn_scalar; +use arrow_schema::DataType; +use num_traits::{ToPrimitive, Zero}; + +use crate::error::{Error, Result}; + +pub(crate) fn infer_dimension( + list_arr: &GenericListArray, +) -> Result> +where + T::Native: OffsetSizeTrait + ToPrimitive, +{ + let len_arr = length(list_arr)?; + if len_arr.is_empty() { + return Ok(Some(Zero::zero())); + } + + let dim = len_arr.as_primitive::().value(0); + if bool_and(&eq_dyn_scalar(len_arr.as_primitive::(), dim)?) != Some(true) { + Ok(None) + } else { + Ok(Some(dim)) + } +} + +/// Infer the vector columns from a dataset. +/// +/// Parameters +/// ---------- +/// - reader: RecordBatchReader +/// - strict: if set to true, only floating-point fixed_size_list columns are considered vector columns. If set to false, +/// floating-point list / large_list columns whose lists all share one length are also considered vector columns. 
+pub fn infer_vector_columns( + reader: impl RecordBatchReader + Send, + strict: bool, +) -> Result> { + let mut columns = vec![]; + + let mut columns_to_infer: HashMap> = HashMap::new(); + for field in reader.schema().fields() { + match field.data_type() { + DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => { + columns.push(field.name().to_string()); + } + DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => { + columns_to_infer.insert(field.name().to_string(), None); + } + DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => { + columns_to_infer.insert(field.name().to_string(), None); + } + _ => {} + } + } + for batch in reader { + let batch = batch?; + let col_names = columns_to_infer.keys().cloned().collect::>(); + for col_name in col_names { + let col = batch.column_by_name(&col_name).ok_or(Error::Schema { + message: format!("Column {} not found", col_name), + })?; + if let Some(dim) = match *col.data_type() { + DataType::List(_) => { + infer_dimension::(col.as_list::())?.map(|d| d as i64) + } + DataType::LargeList(_) => infer_dimension::(col.as_list::())?, + _ => { + return Err(Error::Schema { + message: format!("Column {} is not a list", col_name), + }) + } + } { + if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) { + if prev_dim != &dim { + columns_to_infer.remove(&col_name); + } + } else { + columns_to_infer.insert(col_name, Some(dim)); + } + } else { + columns_to_infer.remove(&col_name); + } + } + } + columns.extend(columns_to_infer.keys().cloned()); + Ok(columns) +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_array::{ + types::{Float32Type, Float64Type}, + FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray, + }; + use arrow_schema::{DataType, Field, Schema}; + use std::{sync::Arc, vec}; + + #[test] + fn test_infer_vector_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("f", 
DataType::Float32, false), + Field::new("s", DataType::Utf8, false), + Field::new( + "l1", + DataType::List(Arc::new(Field::new("item", DataType::Float32, true))), + false, + ), + Field::new( + "l2", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + Field::new( + "fl", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32), + true, + ), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + Arc::new(ListArray::from_iter_primitive::( + (0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])), + )), + // Var-length list + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1.0_f64)]), + Some(vec![Some(2.0_f64), Some(3.0_f64)]), + Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]), + ])), + Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1.0); 32]), + Some(vec![Some(2.0); 32]), + Some(vec![Some(3.0); 32]), + ], + 32, + ), + ), + ], + ) + .unwrap(); + let reader = + RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone()); + + let cols = infer_vector_columns(reader, false).unwrap(); + assert_eq!(cols, vec!["fl", "l1"]); + + let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let cols = infer_vector_columns(reader, true).unwrap(); + assert_eq!(cols, vec!["fl"]); + } +} diff --git a/rust/vectordb/src/data/sanitize.rs b/rust/vectordb/src/data/sanitize.rs new file mode 100644 index 00000000..c5efd2bc --- /dev/null +++ b/rust/vectordb/src/data/sanitize.rs @@ -0,0 +1,284 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{iter::repeat_with, sync::Arc}; + +use arrow_array::{ + cast::AsArray, + types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type}, + Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator, + RecordBatchReader, +}; +use arrow_cast::{can_cast_types, cast}; +use arrow_schema::{ArrowError, DataType, Field, Schema}; +use half::f16; +use lance::arrow::{DataTypeExt, FixedSizeListArrayExt}; +use log::warn; +use num_traits::cast::AsPrimitive; + +use super::inspect::infer_dimension; +use crate::error::Result; + +fn cast_array( + arr: &PrimitiveArray, +) -> Arc> +where + I::Native: AsPrimitive, +{ + Arc::new(PrimitiveArray::::from_iter_values( + arr.values().iter().map(|v| (*v).as_()), + )) +} + +fn cast_float_array( + arr: &PrimitiveArray, + dt: &DataType, +) -> std::result::Result, ArrowError> +where + I::Native: AsPrimitive + AsPrimitive + AsPrimitive, +{ + match dt { + DataType::Float16 => Ok(cast_array::(arr)), + DataType::Float32 => Ok(cast_array::(arr)), + DataType::Float64 => Ok(cast_array::(arr)), + _ => Err(ArrowError::SchemaError(format!( + "Incompatible change field: unable to coerce {:?} to {:?}", + arr.data_type(), + dt + ))), + } +} + +fn coerce_array( + array: &Arc, + field: &Field, +) -> std::result::Result, ArrowError> { + if array.data_type() == field.data_type() { + return Ok(array.clone()); + } + match (array.data_type(), field.data_type()) { + // Normal cast-able types. + (adt, dt) if can_cast_types(adt, dt) => cast(&array, dt), + // Casting between f16/f32/f64 can be lossy. 
+ (adt, dt) if (adt.is_floating() || dt.is_floating()) => { + if adt.byte_width() > dt.byte_width() { + warn!( + "Coercing field {} {:?} to {:?} might lose precision", + field.name(), + adt, + dt + ); + } + match adt { + DataType::Float16 => cast_float_array(array.as_primitive::(), dt), + DataType::Float32 => cast_float_array(array.as_primitive::(), dt), + DataType::Float64 => cast_float_array(array.as_primitive::(), dt), + _ => unreachable!(), + } + } + (adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt { + // Cast a float fixed size array with same dimension to the expected type. + DataType::FixedSizeList(_, dim) if dim == exp_dim => { + let actual_sub = array.as_fixed_size_list(); + let values = coerce_array(actual_sub.values(), exp_field)?; + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + values.clone(), + *dim, + )?) as Arc) + } + DataType::List(_) | DataType::LargeList(_) => { + let Some(dim) = (match adt { + DataType::List(_) => infer_dimension::(array.as_list::()) + .map_err(|e| { + ArrowError::SchemaError(format!( + "failed to infer dimension from list: {}", + e + )) + })? + .map(|d| d as i64), + DataType::LargeList(_) => infer_dimension::(array.as_list::()) + .map_err(|e| { + ArrowError::SchemaError(format!( + "failed to infer dimension from large list: {}", + e + )) + })?, + _ => unreachable!(), + }) else { + return Err(ArrowError::SchemaError(format!( + "Incompatible coerce fixed size list: unable to coerce {:?} from {:?}", + field, + array.data_type() + ))); + }; + + if dim != *exp_dim as i64 { + return Err(ArrowError::SchemaError(format!( + "Incompatible coerce fixed size list: expected dimension {} but got {}", + exp_dim, dim + ))); + } + + let values = coerce_array(array, exp_field)?; + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + values.clone(), + *exp_dim, + )?) 
as Arc) + } + _ => Err(ArrowError::SchemaError(format!( + "Incompatible coerce fixed size list: unable to coerce {:?} from {:?}", + field, + array.data_type() + )))?, + }, + _ => Err(ArrowError::SchemaError(format!( + "Incompatible change field {}: unable to coerce {:?} to {:?}", + field.name(), + array.data_type(), + field.data_type() + )))?, + } +} + +fn coerce_schema_batch( + batch: RecordBatch, + schema: Arc, +) -> std::result::Result { + if batch.schema() == schema { + return Ok(batch); + } + let columns = schema + .fields() + .iter() + .map(|field| { + batch + .column_by_name(field.name()) + .ok_or_else(|| { + ArrowError::SchemaError(format!("Column {} not found", field.name())) + }) + .and_then(|c| coerce_array(c, field)) + }) + .collect::, ArrowError>>()?; + RecordBatch::try_new(schema, columns) +} + +/// Coerce the reader (input data) to match the given [Schema]. +/// +pub fn coerce_schema( + reader: impl RecordBatchReader + Send + 'static, + schema: Arc, +) -> Result> { + if reader.schema() == schema { + return Ok(Box::new(RecordBatchIterator::new(reader, schema))); + } + let s = schema.clone(); + let batches = reader + .zip(repeat_with(move || s.clone())) + .map(|(batch, s)| coerce_schema_batch(batch?, s)); + Ok(Box::new(RecordBatchIterator::new(batches, schema))) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow_array::{ + FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array, + RecordBatch, RecordBatchIterator, StringArray, + }; + use arrow_schema::Field; + use half::f16; + use lance::arrow::FixedSizeListArrayExt; + + #[test] + fn test_coerce_list_to_fixed_size_list() { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "fl", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64), + true, + ), + Field::new("s", DataType::Utf8, true), + Field::new("f", DataType::Float16, true), + Field::new("i", DataType::Int32, true), + ])); + + let batch = 
RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new( + FixedSizeListArray::try_new_from_values( + Float32Array::from_iter_values((0..256).map(|v| v as f32)), + 64, + ) + .unwrap(), + ), + Arc::new(StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("from"), + Some("lance"), + ])), + Arc::new(Float16Array::from_iter_values( + (0..4).map(|v| f16::from_f32(v as f32)), + )), + Arc::new(Int32Array::from_iter_values(0..4)), + ], + ) + .unwrap(); + let reader = + RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone()); + + let expected_schema = Arc::new(Schema::new(vec![ + Field::new( + "fl", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64), + true, + ), + Field::new("s", DataType::Utf8, true), + Field::new("f", DataType::Float64, true), + Field::new("i", DataType::Int8, true), + ])); + let stream = coerce_schema(reader, expected_schema.clone()).unwrap(); + let batches = stream.collect::>(); + assert_eq!(batches.len(), 1); + let batch = batches[0].as_ref().unwrap(); + assert_eq!(batch.schema(), expected_schema); + + let expected = RecordBatch::try_new( + expected_schema, + vec![ + Arc::new( + FixedSizeListArray::try_new_from_values( + Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))), + 64, + ) + .unwrap(), + ), + Arc::new(StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("from"), + Some("lance"), + ])), + Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))), + Arc::new(Int8Array::from_iter_values(0..4)), + ], + ) + .unwrap(); + assert_eq!(batch, &expected); + } +} diff --git a/rust/vectordb/src/error.rs b/rust/vectordb/src/error.rs index 4a8a9820..e5418d3c 100644 --- a/rust/vectordb/src/error.rs +++ b/rust/vectordb/src/error.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use arrow_schema::ArrowError; use snafu::Snafu; #[derive(Debug, Snafu)] @@ -32,10 +33,20 @@ pub enum Error { Store { message: String }, #[snafu(display("LanceDBError: {message}"))] Lance { message: String }, + #[snafu(display("LanceDB Schema Error: {message}"))] + Schema { message: String }, } pub type Result = std::result::Result; +impl From for Error { + fn from(e: ArrowError) -> Self { + Self::Lance { + message: e.to_string(), + } + } +} + impl From for Error { fn from(e: lance::Error) -> Self { Self::Lance { diff --git a/rust/vectordb/src/lib.rs b/rust/vectordb/src/lib.rs index 2a89851b..46d1716e 100644 --- a/rust/vectordb/src/lib.rs +++ b/rust/vectordb/src/lib.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod data; pub mod database; pub mod error; pub mod index;