Added solution to force the type of a column.

This commit is contained in:
Paul Masurel
2023-01-17 15:13:41 +09:00
parent c9cb3d04bf
commit 4640fae516
2 changed files with 120 additions and 47 deletions

View File

@@ -102,18 +102,29 @@ pub(crate) struct NumericalColumnWriter {
column_writer: ColumnWriter,
}
impl NumericalColumnWriter {
pub fn force_numerical_type(&mut self, numerical_type: NumericalType) {
assert!(self
.compatible_numerical_types
.is_type_accepted(numerical_type));
self.compatible_numerical_types = CompatibleNumericalTypes::StaticType(numerical_type);
}
}
/// State used to store what types are still acceptable
/// after having seen a set of numerical values.
#[derive(Clone, Copy)]
struct CompatibleNumericalTypes {
all_values_within_i64_range: bool,
all_values_within_u64_range: bool,
// f64 is always acceptable.
enum CompatibleNumericalTypes {
Dynamic {
all_values_within_i64_range: bool,
all_values_within_u64_range: bool,
},
StaticType(NumericalType),
}
impl Default for CompatibleNumericalTypes {
fn default() -> CompatibleNumericalTypes {
CompatibleNumericalTypes {
CompatibleNumericalTypes::Dynamic {
all_values_within_i64_range: true,
all_values_within_u64_range: true,
}
@@ -121,31 +132,54 @@ impl Default for CompatibleNumericalTypes {
}
impl CompatibleNumericalTypes {
fn is_type_accepted(&self, numerical_type: NumericalType) -> bool {
match self {
CompatibleNumericalTypes::Dynamic {
all_values_within_i64_range,
all_values_within_u64_range,
} => match numerical_type {
NumericalType::I64 => *all_values_within_i64_range,
NumericalType::U64 => *all_values_within_u64_range,
NumericalType::F64 => true,
},
CompatibleNumericalTypes::StaticType(static_numerical_type) => {
*static_numerical_type == numerical_type
}
}
}
fn accept_value(&mut self, numerical_value: NumericalValue) {
match numerical_value {
NumericalValue::I64(val_i64) => {
let value_within_u64_range = val_i64 >= 0i64;
self.all_values_within_u64_range &= value_within_u64_range;
}
NumericalValue::U64(val_u64) => {
let value_within_i64_range = val_u64 < i64::MAX as u64;
self.all_values_within_i64_range &= value_within_i64_range;
}
NumericalValue::F64(_) => {
self.all_values_within_i64_range = false;
self.all_values_within_u64_range = false;
match self {
CompatibleNumericalTypes::Dynamic {
all_values_within_i64_range,
all_values_within_u64_range,
} => match numerical_value {
NumericalValue::I64(val_i64) => {
let value_within_u64_range = val_i64 >= 0i64;
*all_values_within_u64_range &= value_within_u64_range;
}
NumericalValue::U64(val_u64) => {
let value_within_i64_range = val_u64 < i64::MAX as u64;
*all_values_within_i64_range &= value_within_i64_range;
}
NumericalValue::F64(_) => {
*all_values_within_i64_range = false;
*all_values_within_u64_range = false;
}
},
CompatibleNumericalTypes::StaticType(typ) => {
assert_eq!(numerical_value.numerical_type(), *typ);
}
}
}
pub fn to_numerical_type(self) -> NumericalType {
if self.all_values_within_i64_range {
NumericalType::I64
} else if self.all_values_within_u64_range {
NumericalType::U64
} else {
NumericalType::F64
for numerical_type in [NumericalType::I64, NumericalType::U64] {
if self.is_type_accepted(numerical_type) {
return numerical_type;
}
}
NumericalType::F64
}
}
@@ -262,4 +296,27 @@ mod tests {
test_column_writer_coercion_aux(&[1i64.into(), 1u64.into()], NumericalType::I64);
test_column_writer_coercion_aux(&[u64::MAX.into(), (-1i64).into()], NumericalType::F64);
}
#[test]
#[should_panic]
fn test_compatible_numerical_types_static_incompatible_type() {
let mut compatible_numerical_types =
CompatibleNumericalTypes::StaticType(NumericalType::U64);
compatible_numerical_types.accept_value(NumericalValue::I64(1i64));
}
#[test]
fn test_compatible_numerical_types_static_different_type_forbidden() {
let mut compatible_numerical_types =
CompatibleNumericalTypes::StaticType(NumericalType::U64);
compatible_numerical_types.accept_value(NumericalValue::U64(u64::MAX));
}
#[test]
fn test_compatible_numerical_types_static() {
for typ in [NumericalType::I64, NumericalType::I64, NumericalType::F64] {
let compatible_numerical_types = CompatibleNumericalTypes::StaticType(typ);
assert_eq!(compatible_numerical_types.to_numerical_type(), typ);
}
}
}

View File

@@ -68,20 +68,46 @@ impl Default for ColumnarWriter {
}
}
#[inline]
fn mutate_or_create_column<V, TMutator>(
arena_hash_map: &mut ArenaHashMap,
column_name: &str,
updater: TMutator,
) where
V: Copy + 'static,
TMutator: FnMut(Option<V>) -> V,
{
assert!(
!column_name.as_bytes().contains(&0u8),
"key may not contain the 0 byte"
);
arena_hash_map.mutate_or_create(column_name.as_bytes(), updater);
}
impl ColumnarWriter {
pub fn force_numerical_type(&mut self, column_name: &str, numerical_type: NumericalType) {
let (hash_map, _) = (&mut self.numerical_field_hash_map, &mut self.arena);
mutate_or_create_column(
hash_map,
column_name,
|column_opt: Option<NumericalColumnWriter>| {
let mut column: NumericalColumnWriter = column_opt.unwrap_or_default();
column.force_numerical_type(numerical_type);
column
},
);
}
pub fn record_numerical<T: Into<NumericalValue> + Copy>(
&mut self,
doc: RowId,
column_name: &str,
numerical_value: T,
) {
assert!(
!column_name.as_bytes().contains(&0u8),
"key may not contain the 0 byte"
);
let (hash_map, arena) = (&mut self.numerical_field_hash_map, &mut self.arena);
hash_map.mutate_or_create(
column_name.as_bytes(),
mutate_or_create_column(
hash_map,
column_name,
|column_opt: Option<NumericalColumnWriter>| {
let mut column: NumericalColumnWriter = column_opt.unwrap_or_default();
column.record_numerical_value(doc, numerical_value.into(), arena);
@@ -91,33 +117,23 @@ impl ColumnarWriter {
}
pub fn record_bool(&mut self, doc: RowId, column_name: &str, val: bool) {
assert!(
!column_name.as_bytes().contains(&0u8),
"key may not contain the 0 byte"
);
let (hash_map, arena) = (&mut self.bool_field_hash_map, &mut self.arena);
hash_map.mutate_or_create(
column_name.as_bytes(),
|column_opt: Option<ColumnWriter>| {
let mut column: ColumnWriter = column_opt.unwrap_or_default();
column.record(doc, val, arena);
column
},
);
mutate_or_create_column(hash_map, column_name, |column_opt: Option<ColumnWriter>| {
let mut column: ColumnWriter = column_opt.unwrap_or_default();
column.record(doc, val, arena);
column
});
}
pub fn record_str(&mut self, doc: RowId, column_name: &str, value: &str) {
assert!(
!column_name.as_bytes().contains(&0u8),
"key may not contain the 0 byte"
);
let (hash_map, arena, dictionaries) = (
&mut self.bytes_field_hash_map,
&mut self.arena,
&mut self.dictionaries,
);
hash_map.mutate_or_create(
column_name.as_bytes(),
mutate_or_create_column(
hash_map,
column_name,
|column_opt: Option<StrColumnWriter>| {
let mut column: StrColumnWriter = column_opt.unwrap_or_else(|| {
// Each column has its own dictionary