From b2e980b450ae9e13d6975107df07b0dda758cb83 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Sat, 27 Dec 2025 18:37:15 -0700 Subject: [PATCH] Property test for `Comparator`/`ValueRange` consistency, and fixes. --- src/aggregation/metric/top_hits.rs | 46 +++++- src/collector/sort_key/mod.rs | 2 +- src/collector/sort_key/order.rs | 75 ++++++++- src/collector/sort_key/sort_key_computer.rs | 163 ++++++++++++++++++-- src/collector/top_score_collector.rs | 1 + 5 files changed, 264 insertions(+), 23 deletions(-) diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 6a8bdf826..e1da51d6f 100644 --- a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -1,7 +1,8 @@ +use std::cmp::Ordering; use std::collections::HashMap; use std::net::Ipv6Addr; -use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn}; +use columnar::{Column, ColumnType, ColumnarReader, DynamicColumn, ValueRange}; use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR; use common::DateTime; use regex::Regex; @@ -16,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{ }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::AggregationError; -use crate::collector::sort_key::ReverseComparator; +use crate::collector::sort_key::{Comparator, ReverseComparator}; use crate::collector::TopNComputer; use crate::schema::OwnedValue; use crate::{DocAddress, DocId, SegmentOrdinal}; @@ -383,7 +384,7 @@ impl From for OwnedValue { /// Holds a fast field value in its u64 representation, and the order in which it should be sorted. #[derive(Clone, Serialize, Deserialize, Debug)] -struct DocValueAndOrder { +pub(crate) struct DocValueAndOrder { /// A fast field value in its u64 representation. value: Option, /// Sort order for the value @@ -455,6 +456,37 @@ impl PartialEq for DocSortValuesAndFields { impl Eq for DocSortValuesAndFields {} +impl Comparator for ReverseComparator { + #[inline(always)] + fn compare(&self, lhs: &DocSortValuesAndFields, rhs: &DocSortValuesAndFields) -> Ordering { + rhs.cmp(lhs) + } + + fn threshold_to_valuerange( + &self, + threshold: DocSortValuesAndFields, + ) -> ValueRange { + ValueRange::LessThan(threshold, true) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct TopHitsSegmentSortKey(pub Vec); + +impl Comparator for ReverseComparator { + #[inline(always)] + fn compare(&self, lhs: &TopHitsSegmentSortKey, rhs: &TopHitsSegmentSortKey) -> Ordering { + rhs.cmp(lhs) + } + + fn threshold_to_valuerange( + &self, + threshold: TopHitsSegmentSortKey, + ) -> ValueRange { + ValueRange::LessThan(threshold, true) + } +} + /// The TopHitsCollector used for collecting over segments and merging results. #[derive(Clone, Serialize, Deserialize, Debug)] pub struct TopHitsTopNComputer { @@ -518,7 +550,7 @@ impl TopHitsTopNComputer { pub(crate) struct TopHitsSegmentCollector { segment_ordinal: SegmentOrdinal, accessor_idx: usize, - top_n: TopNComputer, DocAddress, ReverseComparator>, + top_n: TopNComputer, } impl TopHitsSegmentCollector { @@ -539,13 +571,15 @@ impl TopHitsSegmentCollector { req: &TopHitsAggregationReq, ) -> TopHitsTopNComputer { let mut top_hits_computer = TopHitsTopNComputer::new(req); + // Map TopHitsSegmentSortKey back to Vec if needed or use directly + // The TopNComputer here stores TopHitsSegmentSortKey. let top_results = self.top_n.into_vec(); for res in top_results { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); top_hits_computer.collect( DocSortValuesAndFields { - sorts: res.sort_key, + sorts: res.sort_key.0, doc_value_fields, }, res.doc, @@ -579,7 +613,7 @@ impl TopHitsSegmentCollector { .collect(); self.top_n.push( - sorts, + TopHitsSegmentSortKey(sorts), DocAddress { segment_ord: self.segment_ordinal, doc_id, diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index b89059809..6f4915ede 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -604,7 +604,7 @@ mod tests { segments_data in proptest::collection::vec( proptest::collection::vec( proptest::option::of(0..100u64), - 1..10_usize // segment size + 1..1000_usize // segment size ), 1..4_usize // num segments ) diff --git a/src/collector/sort_key/order.rs b/src/collector/sort_key/order.rs index a188559aa..75d9b6fe1 100644 --- a/src/collector/sort_key/order.rs +++ b/src/collector/sort_key/order.rs @@ -142,16 +142,68 @@ impl Comparator for NaturalComparator { #[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] pub struct ReverseComparator; -impl Comparator for ReverseComparator -where NaturalComparator: Comparator +macro_rules! impl_reverse_comparator_primitive { + ($($t:ty),*) => { + $( + impl Comparator<$t> for ReverseComparator { + #[inline(always)] + fn compare(&self, lhs: &$t, rhs: &$t) -> Ordering { + NaturalComparator.compare(rhs, lhs) + } + + fn threshold_to_valuerange(&self, threshold: $t) -> ValueRange<$t> { + ValueRange::LessThan(threshold, true) + } + } + )* + } +} + +impl_reverse_comparator_primitive!( + bool, + u8, + u16, + u32, + u64, + u128, + usize, + i8, + i16, + i32, + i64, + i128, + isize, + f32, + f64, + String, + crate::DateTime, + Vec, + crate::schema::Facet +); + +impl Comparator> + for ReverseComparator { #[inline(always)] - fn compare(&self, lhs: &T, rhs: &T) -> Ordering { + fn compare(&self, lhs: &Option, rhs: &Option) -> Ordering { NaturalComparator.compare(rhs, lhs) } - fn threshold_to_valuerange(&self, threshold: T) -> ValueRange { - ValueRange::LessThan(threshold, true) + fn threshold_to_valuerange(&self, threshold: Option) -> ValueRange> { + let is_some = threshold.is_some(); + ValueRange::LessThan(threshold, is_some) + } +} + +impl Comparator for ReverseComparator { + #[inline(always)] + fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering { + NaturalComparator.compare(rhs, lhs) + } + + fn threshold_to_valuerange(&self, threshold: OwnedValue) -> ValueRange { + let is_not_null = !matches!(threshold, OwnedValue::Null); + ValueRange::LessThan(threshold, is_not_null) } } @@ -181,7 +233,11 @@ where ReverseComparator: Comparator } fn threshold_to_valuerange(&self, threshold: Option) -> ValueRange> { - ValueRange::LessThan(threshold, false) + if threshold.is_some() { + ValueRange::LessThan(threshold, false) + } else { + ValueRange::GreaterThan(threshold, false) + } } } @@ -284,7 +340,12 @@ where NaturalComparator: Comparator } fn threshold_to_valuerange(&self, threshold: Option) -> ValueRange> { - ValueRange::GreaterThan(threshold, true) + if threshold.is_some() { + let is_some = threshold.is_some(); + ValueRange::GreaterThan(threshold, is_some) + } else { + ValueRange::LessThan(threshold, false) + } } } diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index 875448f51..517f4a97e 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -743,27 +743,39 @@ pub(crate) fn range_contains_none(range: &ValueRange>) -> bool { match range { ValueRange::All => true, ValueRange::Inclusive(r) => r.contains(&None), - ValueRange::GreaterThan(threshold, match_nulls) => *match_nulls || (None > *threshold), - ValueRange::LessThan(threshold, match_nulls) => *match_nulls || (None < *threshold), + ValueRange::GreaterThan(_threshold, match_nulls) => *match_nulls, + ValueRange::LessThan(_threshold, match_nulls) => *match_nulls, } } pub(crate) fn convert_optional_u64_range_to_u64_range( range: ValueRange>, ) -> ValueRange { - if range_contains_none(&range) { - return ValueRange::All; - } match range { ValueRange::Inclusive(r) => { let start = r.start().unwrap_or(0); let end = r.end().unwrap_or(u64::MAX); ValueRange::Inclusive(start..=end) } - ValueRange::GreaterThan(Some(val), _match_nulls) => ValueRange::GreaterThan(val, false), - ValueRange::GreaterThan(None, _match_nulls) => ValueRange::Inclusive(u64::MIN..=u64::MAX), - ValueRange::LessThan(None, _match_nulls) => ValueRange::Inclusive(1..=0), - _ => ValueRange::All, + ValueRange::GreaterThan(Some(val), match_nulls) => { + ValueRange::GreaterThan(val, match_nulls) + } + ValueRange::GreaterThan(None, match_nulls) => { + if match_nulls { + ValueRange::All + } else { + ValueRange::Inclusive(u64::MIN..=u64::MAX) + } + } + ValueRange::LessThan(None, match_nulls) => { + if match_nulls { + ValueRange::LessThan(u64::MIN, true) + } else { + ValueRange::Inclusive(1..=0) + } + } + ValueRange::LessThan(Some(val), match_nulls) => ValueRange::LessThan(val, match_nulls), + ValueRange::All => ValueRange::All, } } @@ -922,3 +934,136 @@ mod tests { ); } } + +#[cfg(test)] +mod proptest_tests { + use proptest::prelude::*; + + use super::*; + use crate::collector::sort_key::order::*; + + // Re-implement logic to interpret ValueRange> manually to verify expectations + fn range_contains_opt(range: &ValueRange>, val: &Option) -> bool { + match range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(val), + ValueRange::GreaterThan(t, match_nulls) => { + if val.is_none() { + *match_nulls + } else { + val > t + } + } + ValueRange::LessThan(t, match_nulls) => { + if val.is_none() { + *match_nulls + } else { + val < t + } + } + } + } + + fn range_contains_u64(range: &ValueRange, val: &u64) -> bool { + match range { + ValueRange::All => true, + ValueRange::Inclusive(r) => r.contains(val), + ValueRange::GreaterThan(t, _) => val > t, + ValueRange::LessThan(t, _) => val < t, + } + } + + proptest! { + #[test] + fn test_comparator_consistency_natural_none_is_lower( + threshold in any::>(), + val in any::>() + ) { + check_comparator::(threshold, val)?; + } + + #[test] + fn test_comparator_consistency_reverse( + threshold in any::>(), + val in any::>() + ) { + check_comparator::(threshold, val)?; + } + + #[test] + fn test_comparator_consistency_reverse_none_is_lower( + threshold in any::>(), + val in any::>() + ) { + check_comparator::(threshold, val)?; + } + + #[test] + fn test_comparator_consistency_natural_none_is_higher( + threshold in any::>(), + val in any::>() + ) { + check_comparator::(threshold, val)?; + } + } + + fn check_comparator>>( + threshold: Option, + val: Option, + ) -> std::result::Result<(), proptest::test_runner::TestCaseError> { + let comparator = C::default(); + let range = comparator.threshold_to_valuerange(threshold); + let ordering = comparator.compare(&val, &threshold); + let should_be_in_range = ordering == Ordering::Greater; + + let in_range_opt = range_contains_opt(&range, &val); + + prop_assert_eq!( + in_range_opt, + should_be_in_range, + "Comparator consistency failed for {:?}. Threshold: {:?}, Val: {:?}, Range: {:?}, \ + Ordering: {:?}. range_contains_opt says {}, but compare says {}", + std::any::type_name::(), + threshold, + val, + range, + ordering, + in_range_opt, + should_be_in_range + ); + + // Check range_contains_none + let expected_none_in_range = range_contains_opt(&range, &None); + let actual_none_in_range = range_contains_none(&range); + prop_assert_eq!( + actual_none_in_range, + expected_none_in_range, + "range_contains_none failed for {:?}. Range: {:?}. Expected (from \ + range_contains_opt): {}, Actual: {}", + std::any::type_name::(), + range, + expected_none_in_range, + actual_none_in_range + ); + + // Check convert_optional_u64_range_to_u64_range + let u64_range = convert_optional_u64_range_to_u64_range(range.clone()); + if let Some(v) = val { + let in_u64_range = range_contains_u64(&u64_range, &v); + let in_opt_range = range_contains_opt(&range, &Some(v)); + prop_assert_eq!( + in_u64_range, + in_opt_range, + "convert_optional_u64_range_to_u64_range failed for {:?}. Val: {:?}, OptRange: \ + {:?}, U64Range: {:?}. Opt says {}, U64 says {}", + std::any::type_name::(), + v, + range, + u64_range, + in_opt_range, + in_u64_range + ); + } + Ok(()) + } +} diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 1ccd5fd93..3c26faa71 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -597,6 +597,7 @@ where D: Ord, TSortKey: Clone, NaturalComparator: Comparator, + ReverseComparator: Comparator, { /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `2 * top_n`.