Exposing u64-based FastFieldRangeWeight (#2024)

This commit is contained in:
Paul Masurel
2023-05-03 18:32:00 +09:00
committed by GitHub
parent 73452284ae
commit f28ddb711e
7 changed files with 150 additions and 94 deletions

View File

@@ -414,8 +414,8 @@ impl FacetCounts {
pub fn get<T>(&self, facet_from: T) -> FacetChildIterator<'_>
where Facet: From<T> {
let facet = Facet::from(facet_from);
let left_bound = Bound::Excluded(facet.clone());
let right_bound = if facet.is_root() {
let lower_bound = Bound::Excluded(facet.clone());
let upper_bound = if facet.is_root() {
Bound::Unbounded
} else {
let mut facet_after_bytes: String = facet.encoded_str().to_owned();
@@ -424,7 +424,7 @@ impl FacetCounts {
Bound::Excluded(facet_after)
};
let underlying: btree_map::Range<'_, _, _> =
self.facet_counts.range((left_bound, right_bound));
self.facet_counts.range((lower_bound, upper_bound));
FacetChildIterator { underlying }
}

View File

@@ -51,7 +51,7 @@ pub use self::phrase_prefix_query::PhrasePrefixQuery;
pub use self::phrase_query::PhraseQuery;
pub use self::query::{EnableScoring, Query, QueryClone};
pub use self::query_parser::{QueryParser, QueryParserError};
pub use self::range_query::RangeQuery;
pub use self::range_query::{FastFieldRangeWeight, IPFastFieldRangeWeight, RangeQuery};
pub use self::regex_query::RegexQuery;
pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{

View File

@@ -974,8 +974,8 @@ mod test {
let query = make_query_parser().parse_query("title:[A TO B]").unwrap();
assert_eq!(
format!("{:?}", query),
"RangeQuery { field: \"title\", value_type: Str, left_bound: Included([97]), \
right_bound: Included([98]), limit: None }"
"RangeQuery { field: \"title\", value_type: Str, lower_bound: Included([97]), \
upper_bound: Included([98]), limit: None }"
);
}

View File

@@ -7,7 +7,9 @@ mod range_query;
mod range_query_ip_fastfield;
mod range_query_u64_fastfield;
pub use self::range_query::RangeQuery;
pub use self::range_query::{RangeQuery, RangeWeight};
pub use self::range_query_ip_fastfield::IPFastFieldRangeWeight;
pub use self::range_query_u64_fastfield::FastFieldRangeWeight;
// TODO is this correct?
pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
@@ -18,10 +20,7 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
}
}
fn map_bound<TFrom, TTo, Transform: Fn(&TFrom) -> TTo>(
bound: &Bound<TFrom>,
transform: &Transform,
) -> Bound<TTo> {
fn map_bound<TFrom, TTo>(bound: &Bound<TFrom>, transform: impl Fn(&TFrom) -> TTo) -> Bound<TTo> {
use self::Bound::*;
match bound {
Excluded(ref from_val) => Excluded(transform(from_val)),
@@ -29,3 +28,15 @@ fn map_bound<TFrom, TTo, Transform: Fn(&TFrom) -> TTo>(
Unbounded => Unbounded,
}
}
fn map_bound_res<TFrom, TTo, Err>(
bound: &Bound<TFrom>,
transform: impl Fn(&TFrom) -> Result<TTo, Err>,
) -> Result<Bound<TTo>, Err> {
use self::Bound::*;
Ok(match bound {
Excluded(ref from_val) => Excluded(transform(from_val)?),
Included(ref from_val) => Included(transform(from_val)?),
Unbounded => Unbounded,
})
}

View File

@@ -2,6 +2,7 @@ use std::io;
use std::net::Ipv6Addr;
use std::ops::{Bound, Range};
use columnar::MonotonicallyMappableToU128;
use common::{BinarySerializable, BitSet};
use super::map_bound;
@@ -9,8 +10,8 @@ use super::range_query_u64_fastfield::FastFieldRangeWeight;
use crate::core::SegmentReader;
use crate::error::TantivyError;
use crate::query::explanation::does_not_match;
use crate::query::range_query::is_type_valid_for_fastfield_range_query;
use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight;
use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res};
use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption, Term, Type};
use crate::termdict::{TermDictionary, TermStreamer};
@@ -69,8 +70,8 @@ use crate::{DateTime, DocId, Score};
pub struct RangeQuery {
field: String,
value_type: Type,
left_bound: Bound<Vec<u8>>,
right_bound: Bound<Vec<u8>>,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
limit: Option<u64>,
}
@@ -82,15 +83,15 @@ impl RangeQuery {
pub fn new_term_bounds(
field: String,
value_type: Type,
left_bound: &Bound<Term>,
right_bound: &Bound<Term>,
lower_bound: &Bound<Term>,
upper_bound: &Bound<Term>,
) -> RangeQuery {
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
RangeQuery {
field,
value_type,
left_bound: map_bound(left_bound, &verify_and_unwrap_term),
right_bound: map_bound(right_bound, &verify_and_unwrap_term),
lower_bound: map_bound(lower_bound, verify_and_unwrap_term),
upper_bound: map_bound(upper_bound, verify_and_unwrap_term),
limit: None,
}
}
@@ -116,8 +117,8 @@ impl RangeQuery {
/// will panic when the `Weight` object is created.
pub fn new_i64_bounds(
field: String,
left_bound: Bound<i64>,
right_bound: Bound<i64>,
lower_bound: Bound<i64>,
upper_bound: Bound<i64>,
) -> RangeQuery {
let make_term_val = |val: &i64| {
Term::from_field_i64(Field::from_field_id(0), *val)
@@ -127,8 +128,8 @@ impl RangeQuery {
RangeQuery {
field,
value_type: Type::I64,
left_bound: map_bound(&left_bound, &make_term_val),
right_bound: map_bound(&right_bound, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -154,8 +155,8 @@ impl RangeQuery {
/// will panic when the `Weight` object is created.
pub fn new_f64_bounds(
field: String,
left_bound: Bound<f64>,
right_bound: Bound<f64>,
lower_bound: Bound<f64>,
upper_bound: Bound<f64>,
) -> RangeQuery {
let make_term_val = |val: &f64| {
Term::from_field_f64(Field::from_field_id(0), *val)
@@ -165,8 +166,8 @@ impl RangeQuery {
RangeQuery {
field,
value_type: Type::F64,
left_bound: map_bound(&left_bound, &make_term_val),
right_bound: map_bound(&right_bound, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -180,8 +181,8 @@ impl RangeQuery {
/// will panic when the `Weight` object is created.
pub fn new_u64_bounds(
field: String,
left_bound: Bound<u64>,
right_bound: Bound<u64>,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> RangeQuery {
let make_term_val = |val: &u64| {
Term::from_field_u64(Field::from_field_id(0), *val)
@@ -191,8 +192,8 @@ impl RangeQuery {
RangeQuery {
field,
value_type: Type::U64,
left_bound: map_bound(&left_bound, &make_term_val),
right_bound: map_bound(&right_bound, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -203,8 +204,8 @@ impl RangeQuery {
/// will panic when the `Weight` object is created.
pub fn new_ip_bounds(
field: String,
left_bound: Bound<Ipv6Addr>,
right_bound: Bound<Ipv6Addr>,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
) -> RangeQuery {
let make_term_val = |val: &Ipv6Addr| {
Term::from_field_ip_addr(Field::from_field_id(0), *val)
@@ -214,8 +215,8 @@ impl RangeQuery {
RangeQuery {
field,
value_type: Type::IpAddr,
left_bound: map_bound(&left_bound, &make_term_val),
right_bound: map_bound(&right_bound, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -241,8 +242,8 @@ impl RangeQuery {
/// will panic when the `Weight` object is created.
pub fn new_date_bounds(
field: String,
left_bound: Bound<DateTime>,
right_bound: Bound<DateTime>,
lower_bound: Bound<DateTime>,
upper_bound: Bound<DateTime>,
) -> RangeQuery {
let make_term_val = |val: &DateTime| {
Term::from_field_date(Field::from_field_id(0), *val)
@@ -252,8 +253,8 @@ impl RangeQuery {
RangeQuery {
field,
value_type: Type::Date,
left_bound: map_bound(&left_bound, &make_term_val),
right_bound: map_bound(&right_bound, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -277,13 +278,17 @@ impl RangeQuery {
///
/// If the field is not of the type `Str`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_str_bounds(field: String, left: Bound<&str>, right: Bound<&str>) -> RangeQuery {
pub fn new_str_bounds(
field: String,
lower_bound: Bound<&str>,
upper_bound: Bound<&str>,
) -> RangeQuery {
let make_term_val = |val: &&str| val.as_bytes().to_vec();
RangeQuery {
field,
value_type: Type::Str,
left_bound: map_bound(&left, &make_term_val),
right_bound: map_bound(&right, &make_term_val),
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
@@ -340,10 +345,21 @@ impl Query for RangeQuery {
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type) {
if field_type.is_ip_addr() {
let parse_ip_from_bytes = |data: &Vec<u8>| {
let ip_u128_bytes: [u8; 16] = data.as_slice().try_into().map_err(|_| {
crate::TantivyError::InvalidArgument(
"Expected 8 bytes for ip address".to_string(),
)
})?;
let ip_u128 = u128::from_be_bytes(ip_u128_bytes);
crate::Result::<Ipv6Addr>::Ok(Ipv6Addr::from_u128(ip_u128))
};
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
Ok(Box::new(IPFastFieldRangeWeight::new(
self.field.to_string(),
&self.left_bound,
&self.right_bound,
lower_bound,
upper_bound,
)))
} else {
// We run the range query on u64 value space for performance reasons and simpicity
@@ -353,19 +369,19 @@ impl Query for RangeQuery {
u64::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap())
};
let left_bound = map_bound(&self.left_bound, &parse_from_bytes);
let right_bound = map_bound(&self.right_bound, &parse_from_bytes);
Ok(Box::new(FastFieldRangeWeight::new(
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
Ok(Box::new(FastFieldRangeWeight::new_u64_lenient(
self.field.to_string(),
left_bound,
right_bound,
lower_bound,
upper_bound,
)))
}
} else {
Ok(Box::new(RangeWeight {
field: self.field.to_string(),
left_bound: self.left_bound.clone(),
right_bound: self.right_bound.clone(),
lower_bound: self.lower_bound.clone(),
upper_bound: self.upper_bound.clone(),
limit: self.limit,
}))
}
@@ -374,8 +390,8 @@ impl Query for RangeQuery {
pub struct RangeWeight {
field: String,
left_bound: Bound<Vec<u8>>,
right_bound: Bound<Vec<u8>>,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
limit: Option<u64>,
}
@@ -383,12 +399,12 @@ impl RangeWeight {
fn term_range<'a>(&self, term_dict: &'a TermDictionary) -> io::Result<TermStreamer<'a>> {
use std::ops::Bound::*;
let mut term_stream_builder = term_dict.range();
term_stream_builder = match self.left_bound {
term_stream_builder = match self.lower_bound {
Included(ref term_val) => term_stream_builder.ge(term_val),
Excluded(ref term_val) => term_stream_builder.gt(term_val),
Unbounded => term_stream_builder,
};
term_stream_builder = match self.right_bound {
term_stream_builder = match self.upper_bound {
Included(ref term_val) => term_stream_builder.le(term_val),
Excluded(ref term_val) => term_stream_builder.lt(term_val),
Unbounded => term_stream_builder,

View File

@@ -6,9 +6,7 @@ use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{Column, MonotonicallyMappableToU128};
use common::BinarySerializable;
use super::map_bound;
use crate::query::range_query::fast_field_range_query::RangeDocSet;
use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
@@ -16,24 +14,17 @@ use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
/// `IPFastFieldRangeWeight` uses the ip address fast field to execute range queries.
pub struct IPFastFieldRangeWeight {
field: String,
left_bound: Bound<Ipv6Addr>,
right_bound: Bound<Ipv6Addr>,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
}
impl IPFastFieldRangeWeight {
// TODO fix code smell... why do we end up working with Vec<u8> here?
pub fn new(field: String, left_bound: &Bound<Vec<u8>>, right_bound: &Bound<Vec<u8>>) -> Self {
let parse_ip_from_bytes = |data: &Vec<u8>| {
let ip_u128: u128 =
u128::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap());
Ipv6Addr::from_u128(ip_u128)
};
let left_bound = map_bound(left_bound, &parse_ip_from_bytes);
let right_bound = map_bound(right_bound, &parse_ip_from_bytes);
/// Creates a new IPFastFieldRangeWeight.
pub fn new(field: String, lower_bound: Bound<Ipv6Addr>, upper_bound: Bound<Ipv6Addr>) -> Self {
Self {
field,
left_bound,
right_bound,
lower_bound,
upper_bound,
}
}
}
@@ -45,8 +36,8 @@ impl Weight for IPFastFieldRangeWeight {
return Ok(Box::new(EmptyScorer))
};
let value_range = bound_to_value_range(
&self.left_bound,
&self.right_bound,
&self.lower_bound,
&self.upper_bound,
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
@@ -68,18 +59,18 @@ impl Weight for IPFastFieldRangeWeight {
}
fn bound_to_value_range(
left_bound: &Bound<Ipv6Addr>,
right_bound: &Bound<Ipv6Addr>,
lower_bound: &Bound<Ipv6Addr>,
upper_bound: &Bound<Ipv6Addr>,
min_value: Ipv6Addr,
max_value: Ipv6Addr,
) -> RangeInclusive<Ipv6Addr> {
let start_value = match left_bound {
let start_value = match lower_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
Bound::Unbounded => min_value,
};
let end_value = match right_bound {
let end_value = match upper_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
Bound::Unbounded => max_value,
@@ -181,8 +172,8 @@ pub mod tests {
let searcher = index.reader().unwrap().searcher();
let range_weight = IPFastFieldRangeWeight {
field: "ips".to_string(),
left_bound: Bound::Included(ip_addrs[1]),
right_bound: Bound::Included(ip_addrs[2]),
lower_bound: Bound::Included(ip_addrs[1]),
upper_bound: Bound::Included(ip_addrs[2]),
};
let count = range_weight.count(searcher.segment_reader(0)).unwrap();
assert_eq!(count, 2);

View File

@@ -4,41 +4,79 @@
use std::ops::{Bound, RangeInclusive};
use columnar::MonotonicallyMappableToU64;
use columnar::{ColumnType, HasAssociatedColumnType, MonotonicallyMappableToU64};
use super::fast_field_range_query::RangeDocSet;
use super::map_bound;
use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight};
use crate::query::{ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
/// `FastFieldRangeWeight` uses the fast field to execute range queries.
#[derive(Clone, Debug)]
pub struct FastFieldRangeWeight {
field: String,
left_bound: Bound<u64>,
right_bound: Bound<u64>,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
column_type_opt: Option<ColumnType>,
}
impl FastFieldRangeWeight {
pub fn new(field: String, left_bound: Bound<u64>, right_bound: Bound<u64>) -> Self {
let left_bound = map_bound(&left_bound, &|val| *val);
let right_bound = map_bound(&right_bound, &|val| *val);
/// Create a new FastFieldRangeWeight, using the u64 representation of any fast field.
pub(crate) fn new_u64_lenient(
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| *val);
let upper_bound = map_bound(&upper_bound, |val| *val);
Self {
field,
left_bound,
right_bound,
lower_bound,
upper_bound,
column_type_opt: None,
}
}
/// Create a new `FastFieldRangeWeight` for a range of a u64-mappable type .
pub fn new<T: HasAssociatedColumnType + MonotonicallyMappableToU64>(
field: String,
lower_bound: Bound<T>,
upper_bound: Bound<T>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| val.to_u64());
let upper_bound = map_bound(&upper_bound, |val| val.to_u64());
Self {
field,
lower_bound,
upper_bound,
column_type_opt: Some(T::column_type()),
}
}
}
impl Query for FastFieldRangeWeight {
fn weight(
&self,
_enable_scoring: crate::query::EnableScoring<'_>,
) -> crate::Result<Box<dyn Weight>> {
Ok(Box::new(self.clone()))
}
}
impl Weight for FastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let fast_field_reader = reader.fast_fields();
let Some((column, _)) = fast_field_reader.u64_lenient(&self.field)? else {
let column_type_opt: Option<[ColumnType; 1]> =
self.column_type_opt.map(|column_type| [column_type]);
let column_type_opt_ref: Option<&[ColumnType]> = column_type_opt
.as_ref()
.map(|column_types| column_types.as_slice());
let Some((column, _)) = fast_field_reader.u64_lenient_for_type(column_type_opt_ref, &self.field)? else {
return Ok(Box::new(EmptyScorer));
};
let value_range = bound_to_value_range(
&self.left_bound,
&self.right_bound,
&self.lower_bound,
&self.upper_bound,
column.min_value(),
column.max_value(),
);
@@ -64,12 +102,12 @@ impl Weight for FastFieldRangeWeight {
}
fn bound_to_value_range<T: MonotonicallyMappableToU64>(
left_bound: &Bound<T>,
right_bound: &Bound<T>,
lower_bound: &Bound<T>,
upper_bound: &Bound<T>,
min_value: T,
max_value: T,
) -> RangeInclusive<T> {
let mut start_value = match left_bound {
let mut start_value = match lower_bound {
Bound::Included(val) => *val,
Bound::Excluded(val) => T::from_u64(val.to_u64() + 1),
Bound::Unbounded => min_value,
@@ -77,7 +115,7 @@ fn bound_to_value_range<T: MonotonicallyMappableToU64>(
if start_value.partial_cmp(&min_value) == Some(std::cmp::Ordering::Less) {
start_value = min_value;
}
let end_value = match right_bound {
let end_value = match upper_bound {
Bound::Included(val) => *val,
Bound::Excluded(val) => T::from_u64(val.to_u64() - 1),
Bound::Unbounded => max_value,
@@ -170,7 +208,7 @@ pub mod tests {
writer.add_document(doc!(field=>52_000u64)).unwrap();
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_query = FastFieldRangeWeight::new(
let range_query = FastFieldRangeWeight::new_u64_lenient(
"test_field".to_string(),
Bound::Included(50_000),
Bound::Included(50_002),