diff --git a/src/datatypes/src/vectors/dictionary.rs b/src/datatypes/src/vectors/dictionary.rs index 07994d13bd..1ba06a6ca8 100644 --- a/src/datatypes/src/vectors/dictionary.rs +++ b/src/datatypes/src/vectors/dictionary.rs @@ -13,11 +13,11 @@ // limitations under the License. use std::any::Any; +use std::fmt; use std::sync::Arc; -use arrow::array::Array; -use arrow::datatypes::Int64Type; -use arrow_array::{ArrayRef, DictionaryArray, Int64Array}; +use arrow::array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, PrimitiveBuilder}; +use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType}; use serde_json::Value as JsonValue; use snafu::ResultExt; @@ -30,34 +30,55 @@ use crate::vectors::operations::VectorOp; use crate::vectors::{self, Helper, Validity, Vector, VectorRef}; /// Vector of dictionaries, basically backed by Arrow's `DictionaryArray`. -#[derive(Debug, PartialEq)] -pub struct DictionaryVector { - array: DictionaryArray, +pub struct DictionaryVector { + array: DictionaryArray, + /// The datatype of the keys in the dictionary. + key_type: ConcreteDataType, /// The datatype of the items in the dictionary. item_type: ConcreteDataType, /// The vector of items in the dictionary. item_vector: VectorRef, } -impl DictionaryVector { +impl fmt::Debug for DictionaryVector { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("DictionaryVector") + .field("array", &self.array) + .field("key_type", &self.key_type) + .field("item_type", &self.item_type) + .finish() + } +} + +impl PartialEq for DictionaryVector { + fn eq(&self, other: &DictionaryVector) -> bool { + self.array == other.array + && self.key_type == other.key_type + && self.item_type == other.item_type + } +} + +impl DictionaryVector { /// Create a new instance of `DictionaryVector` from a dictionary array and item type - pub fn new(array: DictionaryArray, item_type: ConcreteDataType) -> Result { + pub fn new(array: DictionaryArray, item_type: ConcreteDataType) -> Result { + let key_type = ConcreteDataType::try_from(&K::DATA_TYPE)?; let item_vector = Helper::try_into_vector(array.values())?; Ok(Self { array, + key_type, item_type, item_vector, }) } /// Returns the underlying Arrow dictionary array - pub fn array(&self) -> &DictionaryArray { + pub fn array(&self) -> &DictionaryArray { &self.array } /// Returns the keys array of this dictionary - pub fn keys(&self) -> &arrow_array::PrimitiveArray { + pub fn keys(&self) -> &arrow_array::PrimitiveArray { self.array.keys() } @@ -71,10 +92,10 @@ impl DictionaryVector { } } -impl Vector for DictionaryVector { +impl Vector for DictionaryVector { fn data_type(&self) -> ConcreteDataType { ConcreteDataType::Dictionary(DictionaryType::new( - ConcreteDataType::int64_datatype(), + self.key_type.clone(), self.item_type.clone(), )) } @@ -118,6 +139,7 @@ impl Vector for DictionaryVector { fn slice(&self, offset: usize, length: usize) -> VectorRef { Arc::new(Self { array: self.array.slice(offset, length), + key_type: self.key_type.clone(), item_type: self.item_type.clone(), item_vector: self.item_vector.clone(), }) @@ -129,7 +151,7 @@ impl Vector for DictionaryVector { } let key = self.array.keys().value(index); - self.item_vector.get(key as usize) + self.item_vector.get(key.as_usize()) } fn get_ref(&self, index: usize) -> ValueRef { @@ -138,11 +160,11 @@ impl Vector for DictionaryVector { } let key = self.array.keys().value(index); - self.item_vector.get_ref(key as usize) + self.item_vector.get_ref(key.as_usize()) } } -impl Serializable for DictionaryVector { +impl Serializable for DictionaryVector { fn serialize_to_json(&self) -> Result> { // Convert the dictionary array to JSON, where each element is either null or // the value it refers to in the dictionary @@ -153,7 +175,7 @@ impl Serializable for DictionaryVector { result.push(JsonValue::Null); } else { let key = self.array.keys().value(i); - let value = self.item_vector.get(key as usize); + let value = self.item_vector.get(key.as_usize()); let json_value = serde_json::to_value(value).context(error::SerializeSnafu)?; result.push(json_value); } @@ -163,33 +185,35 @@ impl Serializable for DictionaryVector { } } -impl TryFrom> for DictionaryVector { +impl TryFrom> for DictionaryVector { type Error = crate::error::Error; - fn try_from(array: DictionaryArray) -> Result { - let item_type = ConcreteDataType::from_arrow_type(array.values().data_type()); + fn try_from(array: DictionaryArray) -> Result { + let key_type = ConcreteDataType::try_from(array.keys().data_type())?; + let item_type = ConcreteDataType::try_from(array.values().data_type())?; let item_vector = Helper::try_into_vector(array.values())?; Ok(Self { array, + key_type, item_type, item_vector, }) } } -pub struct DictionaryIter<'a> { - vector: &'a DictionaryVector, +pub struct DictionaryIter<'a, K: ArrowDictionaryKeyType> { + vector: &'a DictionaryVector, idx: usize, } -impl<'a> DictionaryIter<'a> { - pub fn new(vector: &'a DictionaryVector) -> DictionaryIter<'a> { +impl<'a, K: ArrowDictionaryKeyType> DictionaryIter<'a, K> { + pub fn new(vector: &'a DictionaryVector) -> DictionaryIter<'a, K> { DictionaryIter { vector, idx: 0 } } } -impl<'a> Iterator for DictionaryIter<'a> { +impl<'a, K: ArrowDictionaryKeyType> Iterator for DictionaryIter<'a, K> { type Item = Option>; #[inline] @@ -205,7 +229,7 @@ impl<'a> Iterator for DictionaryIter<'a> { return Some(None); } - Some(Some(self.vector.item_vector.get_ref(self.idx))) + Some(Some(self.vector.get_ref(idx))) } #[inline] @@ -217,10 +241,10 @@ impl<'a> Iterator for DictionaryIter<'a> { } } -impl VectorOp for DictionaryVector { +impl VectorOp for DictionaryVector { fn replicate(&self, offsets: &[usize]) -> VectorRef { let keys = self.array.keys(); - let mut replicated_keys = Vec::with_capacity(offsets.len()); + let mut replicated_keys = PrimitiveBuilder::new(); let mut previous_offset = 0; for (i, &offset) in offsets.iter().enumerate() { @@ -236,19 +260,20 @@ impl VectorOp for DictionaryVector { // repeat this key (offset - previous_offset) times let repeat_count = offset - previous_offset; - if repeat_count > 0 { - replicated_keys.resize(replicated_keys.len() + repeat_count, key); + for _ in 0..repeat_count { + replicated_keys.append_option(key); } previous_offset = offset; } - let new_keys = Int64Array::from(replicated_keys); + let new_keys = replicated_keys.finish(); let new_array = DictionaryArray::try_new(new_keys, self.values().clone()) .expect("Failed to create replicated dictionary array"); Arc::new(Self { array: new_array, + key_type: self.key_type.clone(), item_type: self.item_type.clone(), item_vector: self.item_vector.clone(), }) @@ -261,7 +286,7 @@ impl VectorOp for DictionaryVector { let filtered_key_array = filtered_key_vector.to_arrow_array(); let filtered_key_array = filtered_key_array .as_any() - .downcast_ref::() + .downcast_ref::>() .unwrap(); let new_array = DictionaryArray::try_new(filtered_key_array.clone(), self.values().clone()) @@ -269,6 +294,7 @@ impl VectorOp for DictionaryVector { Ok(Arc::new(Self { array: new_array, + key_type: self.key_type.clone(), item_type: self.item_type.clone(), item_vector: self.item_vector.clone(), })) @@ -281,6 +307,7 @@ impl VectorOp for DictionaryVector { .expect("Failed to create casted dictionary array"); Ok(Arc::new(Self { array: new_array, + key_type: self.key_type.clone(), item_type: to_type.clone(), item_vector: self.item_vector.clone(), })) @@ -291,13 +318,17 @@ impl VectorOp for DictionaryVector { let key_vector = Helper::try_into_vector(&key_array)?; let new_key_vector = key_vector.take(indices)?; let new_key_array = new_key_vector.to_arrow_array(); - let new_key_array = new_key_array.as_any().downcast_ref::().unwrap(); + let new_key_array = new_key_array + .as_any() + .downcast_ref::>() + .unwrap(); let new_array = DictionaryArray::try_new(new_key_array.clone(), self.values().clone()) .expect("Failed to create filtered dictionary array"); Ok(Arc::new(Self { array: new_array, + key_type: self.key_type.clone(), item_type: self.item_type.clone(), item_vector: self.item_vector.clone(), })) @@ -308,19 +339,20 @@ impl VectorOp for DictionaryVector { mod tests { use std::sync::Arc; - use arrow_array::StringArray; + use arrow::array::{Int64Array, StringArray, UInt32Array}; + use arrow::datatypes::{Int64Type, UInt32Type}; use super::*; // Helper function to create a test dictionary vector with string values - fn create_test_dictionary() -> DictionaryVector { + fn create_test_dictionary() -> DictionaryVector { // Dictionary values: ["a", "b", "c", "d"] // Keys: [0, 1, 2, null, 1, 3] // Resulting in: ["a", "b", "c", null, "b", "d"] let values = StringArray::from(vec!["a", "b", "c", "d"]); let keys = Int64Array::from(vec![Some(0), Some(1), Some(2), None, Some(1), Some(3)]); let dict_array = DictionaryArray::new(keys, Arc::new(values)); - DictionaryVector::try_from(dict_array).unwrap() + DictionaryVector::::try_from(dict_array).unwrap() } #[test] @@ -435,4 +467,19 @@ mod tests { assert_eq!(taken.get(1), Value::String("a".to_string().into())); assert_eq!(taken.get(2), Value::String("b".to_string().into())); } + + #[test] + fn test_other_type() { + let values = StringArray::from(vec!["a", "b", "c", "d"]); + let keys = UInt32Array::from(vec![Some(0), Some(1), Some(2), None, Some(1), Some(3)]); + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + let dict_vec = DictionaryVector::::try_from(dict_array).unwrap(); + assert_eq!( + ConcreteDataType::dictionary_datatype( + ConcreteDataType::uint32_datatype(), + ConcreteDataType::string_datatype() + ), + dict_vec.data_type() + ); + } } diff --git a/src/datatypes/src/vectors/helper.rs b/src/datatypes/src/vectors/helper.rs index 2457052032..b867d1a7b1 100644 --- a/src/datatypes/src/vectors/helper.rs +++ b/src/datatypes/src/vectors/helper.rs @@ -20,7 +20,10 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute; use arrow::compute::kernels::comparison; -use arrow::datatypes::{DataType as ArrowDataType, Int64Type, TimeUnit}; +use arrow::datatypes::{ + DataType as ArrowDataType, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; use arrow_array::{DictionaryArray, StructArray}; use arrow_schema::IntervalUnit; use datafusion_common::ScalarValue; @@ -351,16 +354,37 @@ impl Helper { ArrowDataType::Decimal128(_, _) => { Arc::new(Decimal128Vector::try_from_arrow_array(array)?) } - ArrowDataType::Dictionary(key, value) if matches!(&**key, ArrowDataType::Int64) => { - let array = array - .as_ref() - .as_any() - .downcast_ref::>() - .unwrap(); // Safety: the type is guarded by match arm condition - Arc::new(DictionaryVector::new( - array.clone(), - ConcreteDataType::try_from(value.as_ref())?, - )?) + ArrowDataType::Dictionary(key, value) => { + macro_rules! handle_dictionary_key_type { + ($key_type:ident) => {{ + let array = array + .as_ref() + .as_any() + .downcast_ref::>() + .unwrap(); // Safety: the type is guarded by match arm condition + Arc::new(DictionaryVector::new( + array.clone(), + ConcreteDataType::try_from(value.as_ref())?, + )?) + }}; + } + + match key.as_ref() { + ArrowDataType::Int8 => handle_dictionary_key_type!(Int8Type), + ArrowDataType::Int16 => handle_dictionary_key_type!(Int16Type), + ArrowDataType::Int32 => handle_dictionary_key_type!(Int32Type), + ArrowDataType::Int64 => handle_dictionary_key_type!(Int64Type), + ArrowDataType::UInt8 => handle_dictionary_key_type!(UInt8Type), + ArrowDataType::UInt16 => handle_dictionary_key_type!(UInt16Type), + ArrowDataType::UInt32 => handle_dictionary_key_type!(UInt32Type), + ArrowDataType::UInt64 => handle_dictionary_key_type!(UInt64Type), + _ => { + return error::UnsupportedArrowTypeSnafu { + arrow_type: array.as_ref().data_type().clone(), + } + .fail() + } + } } ArrowDataType::Struct(_fields) => { @@ -375,7 +399,6 @@ impl Helper { | ArrowDataType::LargeList(_) | ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) - | ArrowDataType::Dictionary(_, _) | ArrowDataType::Decimal256(_, _) | ArrowDataType::Map(_, _) | ArrowDataType::RunEndEncoded(_, _) @@ -629,10 +652,55 @@ mod tests { check_try_into_vector(Time64MicrosecondArray::from(vec![1, 2, 3])); check_try_into_vector(Time64NanosecondArray::from(vec![1, 2, 3])); + // Test dictionary arrays with different key types let values = StringArray::from_iter_values(["a", "b", "c"]); + + // Test Int8 keys let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test Int16 keys + let keys = Int16Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test Int32 keys + let keys = Int32Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test Int64 keys + let keys = Int64Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test UInt8 keys + let keys = UInt8Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test UInt16 keys + let keys = UInt16Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test UInt32 keys + let keys = UInt32Array::from_iter_values([0, 0, 1, 2]); + let array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values.clone())).unwrap()); + Helper::try_into_vector(array).unwrap(); + + // Test UInt64 keys + let keys = UInt64Array::from_iter_values([0, 0, 1, 2]); let array: ArrayRef = Arc::new(DictionaryArray::try_new(keys, Arc::new(values)).unwrap()); - Helper::try_into_vector(array).unwrap_err(); + Helper::try_into_vector(array).unwrap(); } #[test]