diff --git a/columnar/src/columnar/writer/column_writers.rs b/columnar/src/columnar/writer/column_writers.rs index 89f112f95..3400dad9f 100644 --- a/columnar/src/columnar/writer/column_writers.rs +++ b/columnar/src/columnar/writer/column_writers.rs @@ -102,18 +102,29 @@ pub(crate) struct NumericalColumnWriter { column_writer: ColumnWriter, } +impl NumericalColumnWriter { + pub fn force_numerical_type(&mut self, numerical_type: NumericalType) { + assert!(self + .compatible_numerical_types + .is_type_accepted(numerical_type)); + self.compatible_numerical_types = CompatibleNumericalTypes::StaticType(numerical_type); + } +} + /// State used to store what types are still acceptable /// after having seen a set of numerical values. #[derive(Clone, Copy)] -struct CompatibleNumericalTypes { - all_values_within_i64_range: bool, - all_values_within_u64_range: bool, - // f64 is always acceptable. +enum CompatibleNumericalTypes { + Dynamic { + all_values_within_i64_range: bool, + all_values_within_u64_range: bool, + }, + StaticType(NumericalType), } impl Default for CompatibleNumericalTypes { fn default() -> CompatibleNumericalTypes { - CompatibleNumericalTypes { + CompatibleNumericalTypes::Dynamic { all_values_within_i64_range: true, all_values_within_u64_range: true, } @@ -121,31 +132,54 @@ impl Default for CompatibleNumericalTypes { } impl CompatibleNumericalTypes { + fn is_type_accepted(&self, numerical_type: NumericalType) -> bool { + match self { + CompatibleNumericalTypes::Dynamic { + all_values_within_i64_range, + all_values_within_u64_range, + } => match numerical_type { + NumericalType::I64 => *all_values_within_i64_range, + NumericalType::U64 => *all_values_within_u64_range, + NumericalType::F64 => true, + }, + CompatibleNumericalTypes::StaticType(static_numerical_type) => { + *static_numerical_type == numerical_type + } + } + } + fn accept_value(&mut self, numerical_value: NumericalValue) { - match numerical_value { - NumericalValue::I64(val_i64) => { - let value_within_u64_range = val_i64 >= 0i64; - self.all_values_within_u64_range &= value_within_u64_range; - } - NumericalValue::U64(val_u64) => { - let value_within_i64_range = val_u64 < i64::MAX as u64; - self.all_values_within_i64_range &= value_within_i64_range; - } - NumericalValue::F64(_) => { - self.all_values_within_i64_range = false; - self.all_values_within_u64_range = false; + match self { + CompatibleNumericalTypes::Dynamic { + all_values_within_i64_range, + all_values_within_u64_range, + } => match numerical_value { + NumericalValue::I64(val_i64) => { + let value_within_u64_range = val_i64 >= 0i64; + *all_values_within_u64_range &= value_within_u64_range; + } + NumericalValue::U64(val_u64) => { + let value_within_i64_range = val_u64 < i64::MAX as u64; + *all_values_within_i64_range &= value_within_i64_range; + } + NumericalValue::F64(_) => { + *all_values_within_i64_range = false; + *all_values_within_u64_range = false; + } + }, + CompatibleNumericalTypes::StaticType(typ) => { + assert_eq!(numerical_value.numerical_type(), *typ); } } } pub fn to_numerical_type(self) -> NumericalType { - if self.all_values_within_i64_range { - NumericalType::I64 - } else if self.all_values_within_u64_range { - NumericalType::U64 - } else { - NumericalType::F64 + for numerical_type in [NumericalType::I64, NumericalType::U64] { + if self.is_type_accepted(numerical_type) { + return numerical_type; + } } + NumericalType::F64 } } @@ -262,4 +296,27 @@ mod tests { test_column_writer_coercion_aux(&[1i64.into(), 1u64.into()], NumericalType::I64); test_column_writer_coercion_aux(&[u64::MAX.into(), (-1i64).into()], NumericalType::F64); } + + #[test] + #[should_panic] + fn test_compatible_numerical_types_static_incompatible_type() { + let mut compatible_numerical_types = + CompatibleNumericalTypes::StaticType(NumericalType::U64); + compatible_numerical_types.accept_value(NumericalValue::I64(1i64)); + } + + #[test] + fn test_compatible_numerical_types_static_different_type_forbidden() { + let mut compatible_numerical_types = + CompatibleNumericalTypes::StaticType(NumericalType::U64); + compatible_numerical_types.accept_value(NumericalValue::U64(u64::MAX)); + } + + #[test] + fn test_compatible_numerical_types_static() { + for typ in [NumericalType::I64, NumericalType::I64, NumericalType::F64] { + let compatible_numerical_types = CompatibleNumericalTypes::StaticType(typ); + assert_eq!(compatible_numerical_types.to_numerical_type(), typ); + } + } } diff --git a/columnar/src/columnar/writer/mod.rs b/columnar/src/columnar/writer/mod.rs index c3522448b..962aaf8c0 100644 --- a/columnar/src/columnar/writer/mod.rs +++ b/columnar/src/columnar/writer/mod.rs @@ -68,20 +68,46 @@ impl Default for ColumnarWriter { } } +#[inline] +fn mutate_or_create_column( + arena_hash_map: &mut ArenaHashMap, + column_name: &str, + updater: TMutator, +) where + V: Copy + 'static, + TMutator: FnMut(Option) -> V, +{ + assert!( + !column_name.as_bytes().contains(&0u8), + "key may not contain the 0 byte" + ); + arena_hash_map.mutate_or_create(column_name.as_bytes(), updater); +} + impl ColumnarWriter { + pub fn force_numerical_type(&mut self, column_name: &str, numerical_type: NumericalType) { + let (hash_map, _) = (&mut self.numerical_field_hash_map, &mut self.arena); + mutate_or_create_column( + hash_map, + column_name, + |column_opt: Option| { + let mut column: NumericalColumnWriter = column_opt.unwrap_or_default(); + column.force_numerical_type(numerical_type); + column + }, + ); + } + pub fn record_numerical + Copy>( &mut self, doc: RowId, column_name: &str, numerical_value: T, ) { - assert!( - !column_name.as_bytes().contains(&0u8), - "key may not contain the 0 byte" - ); let (hash_map, arena) = (&mut self.numerical_field_hash_map, &mut self.arena); - hash_map.mutate_or_create( - column_name.as_bytes(), + mutate_or_create_column( + hash_map, + column_name, |column_opt: Option| { let mut column: NumericalColumnWriter = column_opt.unwrap_or_default(); column.record_numerical_value(doc, numerical_value.into(), arena); @@ -91,33 +117,23 @@ impl ColumnarWriter { } pub fn record_bool(&mut self, doc: RowId, column_name: &str, val: bool) { - assert!( - !column_name.as_bytes().contains(&0u8), - "key may not contain the 0 byte" - ); let (hash_map, arena) = (&mut self.bool_field_hash_map, &mut self.arena); - hash_map.mutate_or_create( - column_name.as_bytes(), - |column_opt: Option| { - let mut column: ColumnWriter = column_opt.unwrap_or_default(); - column.record(doc, val, arena); - column - }, - ); + mutate_or_create_column(hash_map, column_name, |column_opt: Option| { + let mut column: ColumnWriter = column_opt.unwrap_or_default(); + column.record(doc, val, arena); + column + }); } pub fn record_str(&mut self, doc: RowId, column_name: &str, value: &str) { - assert!( - !column_name.as_bytes().contains(&0u8), - "key may not contain the 0 byte" - ); let (hash_map, arena, dictionaries) = ( &mut self.bytes_field_hash_map, &mut self.arena, &mut self.dictionaries, ); - hash_map.mutate_or_create( - column_name.as_bytes(), + mutate_or_create_column( + hash_map, + column_name, |column_opt: Option| { let mut column: StrColumnWriter = column_opt.unwrap_or_else(|| { // Each column has its own dictionary