diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index 45ffab021..22de5c3bc 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -414,8 +414,8 @@ impl FacetCounts { pub fn get(&self, facet_from: T) -> FacetChildIterator<'_> where Facet: From { let facet = Facet::from(facet_from); - let left_bound = Bound::Excluded(facet.clone()); - let right_bound = if facet.is_root() { + let lower_bound = Bound::Excluded(facet.clone()); + let upper_bound = if facet.is_root() { Bound::Unbounded } else { let mut facet_after_bytes: String = facet.encoded_str().to_owned(); @@ -424,7 +424,7 @@ impl FacetCounts { Bound::Excluded(facet_after) }; let underlying: btree_map::Range<'_, _, _> = - self.facet_counts.range((left_bound, right_bound)); + self.facet_counts.range((lower_bound, upper_bound)); FacetChildIterator { underlying } } diff --git a/src/query/mod.rs b/src/query/mod.rs index 28da5ee56..8fc23169e 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -51,7 +51,7 @@ pub use self::phrase_prefix_query::PhrasePrefixQuery; pub use self::phrase_query::PhraseQuery; pub use self::query::{EnableScoring, Query, QueryClone}; pub use self::query_parser::{QueryParser, QueryParserError}; -pub use self::range_query::RangeQuery; +pub use self::range_query::{FastFieldRangeWeight, IPFastFieldRangeWeight, RangeQuery}; pub use self::regex_query::RegexQuery; pub use self::reqopt_scorer::RequiredOptionalScorer; pub use self::score_combiner::{ diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index a3ca5f2ca..6b66f93aa 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -974,8 +974,8 @@ mod test { let query = make_query_parser().parse_query("title:[A TO B]").unwrap(); assert_eq!( format!("{:?}", query), - "RangeQuery { field: \"title\", value_type: Str, left_bound: Included([97]), \ - right_bound: Included([98]), limit: None }" + "RangeQuery { 
field: \"title\", value_type: Str, lower_bound: Included([97]), \ + upper_bound: Included([98]), limit: None }" ); } diff --git a/src/query/range_query/mod.rs b/src/query/range_query/mod.rs index 6b4e27f0d..a4f2aec63 100644 --- a/src/query/range_query/mod.rs +++ b/src/query/range_query/mod.rs @@ -7,7 +7,9 @@ mod range_query; mod range_query_ip_fastfield; mod range_query_u64_fastfield; -pub use self::range_query::RangeQuery; +pub use self::range_query::{RangeQuery, RangeWeight}; +pub use self::range_query_ip_fastfield::IPFastFieldRangeWeight; +pub use self::range_query_u64_fastfield::FastFieldRangeWeight; // TODO is this correct? pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool { @@ -18,10 +20,7 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool { } } -fn map_bound TTo>( - bound: &Bound, - transform: &Transform, -) -> Bound { +fn map_bound(bound: &Bound, transform: impl Fn(&TFrom) -> TTo) -> Bound { use self::Bound::*; match bound { Excluded(ref from_val) => Excluded(transform(from_val)), @@ -29,3 +28,15 @@ fn map_bound TTo>( Unbounded => Unbounded, } } + +fn map_bound_res( + bound: &Bound, + transform: impl Fn(&TFrom) -> Result, +) -> Result, Err> { + use self::Bound::*; + Ok(match bound { + Excluded(ref from_val) => Excluded(transform(from_val)?), + Included(ref from_val) => Included(transform(from_val)?), + Unbounded => Unbounded, + }) +} diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 8e080368f..09bf735d8 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -2,6 +2,7 @@ use std::io; use std::net::Ipv6Addr; use std::ops::{Bound, Range}; +use columnar::MonotonicallyMappableToU128; use common::{BinarySerializable, BitSet}; use super::map_bound; @@ -9,8 +10,8 @@ use super::range_query_u64_fastfield::FastFieldRangeWeight; use crate::core::SegmentReader; use crate::error::TantivyError; use crate::query::explanation::does_not_match; 
-use crate::query::range_query::is_type_valid_for_fastfield_range_query; use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight; +use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res}; use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption, Term, Type}; use crate::termdict::{TermDictionary, TermStreamer}; @@ -69,8 +70,8 @@ use crate::{DateTime, DocId, Score}; pub struct RangeQuery { field: String, value_type: Type, - left_bound: Bound>, - right_bound: Bound>, + lower_bound: Bound>, + upper_bound: Bound>, limit: Option, } @@ -82,15 +83,15 @@ impl RangeQuery { pub fn new_term_bounds( field: String, value_type: Type, - left_bound: &Bound, - right_bound: &Bound, + lower_bound: &Bound, + upper_bound: &Bound, ) -> RangeQuery { let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned(); RangeQuery { field, value_type, - left_bound: map_bound(left_bound, &verify_and_unwrap_term), - right_bound: map_bound(right_bound, &verify_and_unwrap_term), + lower_bound: map_bound(lower_bound, verify_and_unwrap_term), + upper_bound: map_bound(upper_bound, verify_and_unwrap_term), limit: None, } } @@ -116,8 +117,8 @@ impl RangeQuery { /// will panic when the `Weight` object is created. pub fn new_i64_bounds( field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, ) -> RangeQuery { let make_term_val = |val: &i64| { Term::from_field_i64(Field::from_field_id(0), *val) @@ -127,8 +128,8 @@ impl RangeQuery { RangeQuery { field, value_type: Type::I64, - left_bound: map_bound(&left_bound, &make_term_val), - right_bound: map_bound(&right_bound, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -154,8 +155,8 @@ impl RangeQuery { /// will panic when the `Weight` object is created. 
pub fn new_f64_bounds( field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, ) -> RangeQuery { let make_term_val = |val: &f64| { Term::from_field_f64(Field::from_field_id(0), *val) @@ -165,8 +166,8 @@ impl RangeQuery { RangeQuery { field, value_type: Type::F64, - left_bound: map_bound(&left_bound, &make_term_val), - right_bound: map_bound(&right_bound, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -180,8 +181,8 @@ impl RangeQuery { /// will panic when the `Weight` object is created. pub fn new_u64_bounds( field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, ) -> RangeQuery { let make_term_val = |val: &u64| { Term::from_field_u64(Field::from_field_id(0), *val) @@ -191,8 +192,8 @@ impl RangeQuery { RangeQuery { field, value_type: Type::U64, - left_bound: map_bound(&left_bound, &make_term_val), - right_bound: map_bound(&right_bound, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -203,8 +204,8 @@ impl RangeQuery { /// will panic when the `Weight` object is created. pub fn new_ip_bounds( field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, ) -> RangeQuery { let make_term_val = |val: &Ipv6Addr| { Term::from_field_ip_addr(Field::from_field_id(0), *val) @@ -214,8 +215,8 @@ impl RangeQuery { RangeQuery { field, value_type: Type::IpAddr, - left_bound: map_bound(&left_bound, &make_term_val), - right_bound: map_bound(&right_bound, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -241,8 +242,8 @@ impl RangeQuery { /// will panic when the `Weight` object is created. 
pub fn new_date_bounds( field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, ) -> RangeQuery { let make_term_val = |val: &DateTime| { Term::from_field_date(Field::from_field_id(0), *val) @@ -252,8 +253,8 @@ impl RangeQuery { RangeQuery { field, value_type: Type::Date, - left_bound: map_bound(&left_bound, &make_term_val), - right_bound: map_bound(&right_bound, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -277,13 +278,17 @@ impl RangeQuery { /// /// If the field is not of the type `Str`, tantivy /// will panic when the `Weight` object is created. - pub fn new_str_bounds(field: String, left: Bound<&str>, right: Bound<&str>) -> RangeQuery { + pub fn new_str_bounds( + field: String, + lower_bound: Bound<&str>, + upper_bound: Bound<&str>, + ) -> RangeQuery { let make_term_val = |val: &&str| val.as_bytes().to_vec(); RangeQuery { field, value_type: Type::Str, - left_bound: map_bound(&left, &make_term_val), - right_bound: map_bound(&right, &make_term_val), + lower_bound: map_bound(&lower_bound, make_term_val), + upper_bound: map_bound(&upper_bound, make_term_val), limit: None, } } @@ -340,10 +345,21 @@ impl Query for RangeQuery { if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type) { if field_type.is_ip_addr() { + let parse_ip_from_bytes = |data: &Vec| { + let ip_u128_bytes: [u8; 16] = data.as_slice().try_into().map_err(|_| { + crate::TantivyError::InvalidArgument( + "Expected 16 bytes for ip address".to_string(), + ) + })?; + let ip_u128 = u128::from_be_bytes(ip_u128_bytes); + crate::Result::::Ok(Ipv6Addr::from_u128(ip_u128)) + }; + let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?; + let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?; Ok(Box::new(IPFastFieldRangeWeight::new( self.field.to_string(), - &self.left_bound, - &self.right_bound, + 
lower_bound, + upper_bound, ))) } else { // We run the range query on u64 value space for performance reasons and simpicity @@ -353,19 +369,19 @@ impl Query for RangeQuery { u64::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap()) }; - let left_bound = map_bound(&self.left_bound, &parse_from_bytes); - let right_bound = map_bound(&self.right_bound, &parse_from_bytes); - Ok(Box::new(FastFieldRangeWeight::new( + let lower_bound = map_bound(&self.lower_bound, parse_from_bytes); + let upper_bound = map_bound(&self.upper_bound, parse_from_bytes); + Ok(Box::new(FastFieldRangeWeight::new_u64_lenient( self.field.to_string(), - left_bound, - right_bound, + lower_bound, + upper_bound, ))) } } else { Ok(Box::new(RangeWeight { field: self.field.to_string(), - left_bound: self.left_bound.clone(), - right_bound: self.right_bound.clone(), + lower_bound: self.lower_bound.clone(), + upper_bound: self.upper_bound.clone(), limit: self.limit, })) } @@ -374,8 +390,8 @@ impl Query for RangeQuery { pub struct RangeWeight { field: String, - left_bound: Bound>, - right_bound: Bound>, + lower_bound: Bound>, + upper_bound: Bound>, limit: Option, } @@ -383,12 +399,12 @@ impl RangeWeight { fn term_range<'a>(&self, term_dict: &'a TermDictionary) -> io::Result> { use std::ops::Bound::*; let mut term_stream_builder = term_dict.range(); - term_stream_builder = match self.left_bound { + term_stream_builder = match self.lower_bound { Included(ref term_val) => term_stream_builder.ge(term_val), Excluded(ref term_val) => term_stream_builder.gt(term_val), Unbounded => term_stream_builder, }; - term_stream_builder = match self.right_bound { + term_stream_builder = match self.upper_bound { Included(ref term_val) => term_stream_builder.le(term_val), Excluded(ref term_val) => term_stream_builder.lt(term_val), Unbounded => term_stream_builder, diff --git a/src/query/range_query/range_query_ip_fastfield.rs b/src/query/range_query/range_query_ip_fastfield.rs index 1cc6c1871..c8e303e3e 100644 --- 
a/src/query/range_query/range_query_ip_fastfield.rs +++ b/src/query/range_query/range_query_ip_fastfield.rs @@ -6,9 +6,7 @@ use std::net::Ipv6Addr; use std::ops::{Bound, RangeInclusive}; use columnar::{Column, MonotonicallyMappableToU128}; -use common::BinarySerializable; -use super::map_bound; use crate::query::range_query::fast_field_range_query::RangeDocSet; use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, TantivyError}; @@ -16,24 +14,17 @@ use crate::{DocId, DocSet, Score, SegmentReader, TantivyError}; /// `IPFastFieldRangeWeight` uses the ip address fast field to execute range queries. pub struct IPFastFieldRangeWeight { field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, } impl IPFastFieldRangeWeight { - // TODO fix code smell... why do we end up working with Vec here? - pub fn new(field: String, left_bound: &Bound>, right_bound: &Bound>) -> Self { - let parse_ip_from_bytes = |data: &Vec| { - let ip_u128: u128 = - u128::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap()); - Ipv6Addr::from_u128(ip_u128) - }; - let left_bound = map_bound(left_bound, &parse_ip_from_bytes); - let right_bound = map_bound(right_bound, &parse_ip_from_bytes); + /// Creates a new IPFastFieldRangeWeight. 
+ pub fn new(field: String, lower_bound: Bound, upper_bound: Bound) -> Self { Self { field, - left_bound, - right_bound, + lower_bound, + upper_bound, } } } @@ -45,8 +36,8 @@ impl Weight for IPFastFieldRangeWeight { return Ok(Box::new(EmptyScorer)) }; let value_range = bound_to_value_range( - &self.left_bound, - &self.right_bound, + &self.lower_bound, + &self.upper_bound, ip_addr_column.min_value(), ip_addr_column.max_value(), ); @@ -68,18 +59,18 @@ impl Weight for IPFastFieldRangeWeight { } fn bound_to_value_range( - left_bound: &Bound, - right_bound: &Bound, + lower_bound: &Bound, + upper_bound: &Bound, min_value: Ipv6Addr, max_value: Ipv6Addr, ) -> RangeInclusive { - let start_value = match left_bound { + let start_value = match lower_bound { Bound::Included(ip_addr) => *ip_addr, Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1), Bound::Unbounded => min_value, }; - let end_value = match right_bound { + let end_value = match upper_bound { Bound::Included(ip_addr) => *ip_addr, Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1), Bound::Unbounded => max_value, @@ -181,8 +172,8 @@ pub mod tests { let searcher = index.reader().unwrap().searcher(); let range_weight = IPFastFieldRangeWeight { field: "ips".to_string(), - left_bound: Bound::Included(ip_addrs[1]), - right_bound: Bound::Included(ip_addrs[2]), + lower_bound: Bound::Included(ip_addrs[1]), + upper_bound: Bound::Included(ip_addrs[2]), }; let count = range_weight.count(searcher.segment_reader(0)).unwrap(); assert_eq!(count, 2); diff --git a/src/query/range_query/range_query_u64_fastfield.rs b/src/query/range_query/range_query_u64_fastfield.rs index d79767eb9..a508bd258 100644 --- a/src/query/range_query/range_query_u64_fastfield.rs +++ b/src/query/range_query/range_query_u64_fastfield.rs @@ -4,41 +4,79 @@ use std::ops::{Bound, RangeInclusive}; -use columnar::MonotonicallyMappableToU64; +use columnar::{ColumnType, HasAssociatedColumnType, MonotonicallyMappableToU64}; use 
super::fast_field_range_query::RangeDocSet; use super::map_bound; -use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight}; +use crate::query::{ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, TantivyError}; /// `FastFieldRangeWeight` uses the fast field to execute range queries. +#[derive(Clone, Debug)] pub struct FastFieldRangeWeight { field: String, - left_bound: Bound, - right_bound: Bound, + lower_bound: Bound, + upper_bound: Bound, + column_type_opt: Option, } impl FastFieldRangeWeight { - pub fn new(field: String, left_bound: Bound, right_bound: Bound) -> Self { - let left_bound = map_bound(&left_bound, &|val| *val); - let right_bound = map_bound(&right_bound, &|val| *val); + /// Create a new FastFieldRangeWeight, using the u64 representation of any fast field. + pub(crate) fn new_u64_lenient( + field: String, + lower_bound: Bound, + upper_bound: Bound, + ) -> Self { + let lower_bound = map_bound(&lower_bound, |val| *val); + let upper_bound = map_bound(&upper_bound, |val| *val); Self { field, - left_bound, - right_bound, + lower_bound, + upper_bound, + column_type_opt: None, } } + + /// Create a new `FastFieldRangeWeight` for a range of a u64-mappable type. + pub fn new( + field: String, + lower_bound: Bound, + upper_bound: Bound, + ) -> Self { + let lower_bound = map_bound(&lower_bound, |val| val.to_u64()); + let upper_bound = map_bound(&upper_bound, |val| val.to_u64()); + Self { + field, + lower_bound, + upper_bound, + column_type_opt: Some(T::column_type()), + } + } +} + +impl Query for FastFieldRangeWeight { + fn weight( + &self, + _enable_scoring: crate::query::EnableScoring<'_>, + ) -> crate::Result> { + Ok(Box::new(self.clone())) + } } impl Weight for FastFieldRangeWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { let fast_field_reader = reader.fast_fields(); - let Some((column, _)) = fast_field_reader.u64_lenient(&self.field)? 
else { + let column_type_opt: Option<[ColumnType; 1]> = + self.column_type_opt.map(|column_type| [column_type]); + let column_type_opt_ref: Option<&[ColumnType]> = column_type_opt + .as_ref() + .map(|column_types| column_types.as_slice()); + let Some((column, _)) = fast_field_reader.u64_lenient_for_type(column_type_opt_ref, &self.field)? else { return Ok(Box::new(EmptyScorer)); }; let value_range = bound_to_value_range( - &self.left_bound, - &self.right_bound, + &self.lower_bound, + &self.upper_bound, column.min_value(), column.max_value(), ); @@ -64,12 +102,12 @@ impl Weight for FastFieldRangeWeight { } fn bound_to_value_range( - left_bound: &Bound, - right_bound: &Bound, + lower_bound: &Bound, + upper_bound: &Bound, min_value: T, max_value: T, ) -> RangeInclusive { - let mut start_value = match left_bound { + let mut start_value = match lower_bound { Bound::Included(val) => *val, Bound::Excluded(val) => T::from_u64(val.to_u64() + 1), Bound::Unbounded => min_value, @@ -77,7 +115,7 @@ fn bound_to_value_range( if start_value.partial_cmp(&min_value) == Some(std::cmp::Ordering::Less) { start_value = min_value; } - let end_value = match right_bound { + let end_value = match upper_bound { Bound::Included(val) => *val, Bound::Excluded(val) => T::from_u64(val.to_u64() - 1), Bound::Unbounded => max_value, @@ -170,7 +208,7 @@ pub mod tests { writer.add_document(doc!(field=>52_000u64)).unwrap(); writer.commit().unwrap(); let searcher = index.reader().unwrap().searcher(); - let range_query = FastFieldRangeWeight::new( + let range_query = FastFieldRangeWeight::new_u64_lenient( "test_field".to_string(), Bound::Included(50_000), Bound::Included(50_002),