diff --git a/zenith_utils/src/ordered_vec.rs b/zenith_utils/src/ordered_vec.rs index f35c4f6ca8..14837e0f3d 100644 --- a/zenith_utils/src/ordered_vec.rs +++ b/zenith_utils/src/ordered_vec.rs @@ -20,6 +20,12 @@ impl OrderedVec { } pub fn range>(&self, range: R) -> &[(K, V)] { + match (range.start_bound(), range.end_bound()) { + (Bound::Excluded(l), Bound::Excluded(u)) if l == u => panic!("Invalid excluded"), + // TODO check for l <= x with or patterns + _ => {} + } + let start_idx = match range.start_bound() { Bound::Included(key) => match self.0.binary_search_by_key(key, extract_key) { Ok(idx) => idx, @@ -34,12 +40,12 @@ impl OrderedVec { let end_idx = match range.end_bound() { Bound::Included(key) => match self.0.binary_search_by_key(key, extract_key) { - Ok(idx) => idx, - Err(idx) => idx + 1, + Ok(idx) => idx + 1, + Err(idx) => idx, }, Bound::Excluded(key) => match self.0.binary_search_by_key(key, extract_key) { - Ok(idx) => idx + 1, - Err(idx) => idx + 1, + Ok(idx) => idx, + Err(idx) => idx, }, Bound::Unbounded => self.0.len(), }; @@ -74,3 +80,135 @@ impl From> for OrderedVec { fn extract_key(pair: &(K, V)) -> K { pair.0 } + +#[cfg(test)] +mod tests { + use std::{ + collections::BTreeMap, + ops::{Bound, RangeBounds}, + }; + + use super::OrderedVec; + + #[test] + #[should_panic] + fn invalid_range() { + let mut map = BTreeMap::new(); + map.insert(0, ()); + + let vec: OrderedVec = OrderedVec::from(map); + struct InvalidRange; + impl RangeBounds for InvalidRange { + fn start_bound(&self) -> Bound<&i32> { + Bound::Excluded(&0) + } + + fn end_bound(&self) -> Bound<&i32> { + Bound::Excluded(&0) + } + } + + vec.range(InvalidRange); + } + + #[test] + fn range_tests() { + let mut map = BTreeMap::new(); + map.insert(0, ()); + map.insert(2, ()); + map.insert(4, ()); + let vec = OrderedVec::from(map); + + assert_eq!(vec.range(0..0), &[]); + assert_eq!(vec.range(0..1), &[(0, ())]); + assert_eq!(vec.range(0..2), &[(0, ())]); + assert_eq!(vec.range(0..3), &[(0, ()), (2, ())]); + + assert_eq!(vec.range(..0), &[]); + assert_eq!(vec.range(..1), &[(0, ())]); + + assert_eq!(vec.range(..3), &[(0, ()), (2, ())]); + assert_eq!(vec.range(..3), &[(0, ()), (2, ())]); + + assert_eq!(vec.range(0..=0), &[(0, ())]); + assert_eq!(vec.range(0..=1), &[(0, ())]); + assert_eq!(vec.range(0..=2), &[(0, ()), (2, ())]); + assert_eq!(vec.range(0..=3), &[(0, ()), (2, ())]); + + assert_eq!(vec.range(..=0), &[(0, ())]); + assert_eq!(vec.range(..=1), &[(0, ())]); + assert_eq!(vec.range(..=2), &[(0, ()), (2, ())]); + assert_eq!(vec.range(..=3), &[(0, ()), (2, ())]); + } + + struct BoundIter { + min: i32, + max: i32, + + next: Option>, + } + + impl BoundIter { + fn new(min: i32, max: i32) -> Self { + Self { + min, + max, + + next: Some(Bound::Unbounded), + } + } + } + + impl Iterator for BoundIter { + type Item = Bound; + + fn next(&mut self) -> Option { + let cur = self.next?; + + self.next = match &cur { + Bound::Unbounded => Some(Bound::Included(self.min)), + Bound::Included(x) => { + if *x >= self.max { + Some(Bound::Excluded(self.min)) + } else { + Some(Bound::Included(x + 1)) + } + } + Bound::Excluded(x) => { + if *x >= self.max { + None + } else { + Some(Bound::Excluded(x + 1)) + } + } + }; + + Some(cur) + } + } + + #[test] + fn range_exhaustive() { + let map: BTreeMap = (1..=7).step_by(2).map(|x| (x, ())).collect(); + let vec = OrderedVec::from(map.clone()); + + const RANGE_MIN: i32 = 0; + const RANGE_MAX: i32 = 8; + for lower_bound in BoundIter::new(RANGE_MIN, RANGE_MAX) { + let ub_min = match lower_bound { + Bound::Unbounded => RANGE_MIN, + Bound::Included(x) => x, + Bound::Excluded(x) => x + 1, + }; + for upper_bound in BoundIter::new(ub_min, RANGE_MAX) { + let map_range: Vec<(i32, ())> = map + .range((lower_bound, upper_bound)) + .map(|(&x, _)| (x, ())) + .collect(); + let vec_slice = vec.range((lower_bound, upper_bound)); + + assert_eq!(map_range, vec_slice); + } + } + } +}