improve validation in aggregation, extend invalid field test (#1292)

* improve validation in aggregation, extend invalid field test

improve validation in aggregation
extend invalid field test
Fixes #1291

* collect fast field names on request structure

* fix visibility of AggregationSegmentCollector
This commit is contained in:
PSeitz
2022-02-25 07:21:19 +01:00
committed by GitHub
parent d7b46d2137
commit c4f66eb185
10 changed files with 260 additions and 83 deletions

View File

@@ -44,7 +44,7 @@
//! assert_eq!(agg_req1, agg_req2);
//! ```
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
@@ -57,6 +57,15 @@ use super::metric::{AverageAggregation, StatsAggregation};
/// The key is the user defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
/// Collects the names of every fast field referenced anywhere in the
/// aggregation request tree (buckets, metrics and nested sub-aggregations).
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
    let mut names = HashSet::new();
    for agg in aggs.values() {
        agg.get_fast_field_names(&mut names);
    }
    names
}
/// Aggregation request of [BucketAggregation] or [MetricAggregation].
///
/// An aggregation is either a bucket or a metric.
@@ -69,6 +78,15 @@ pub enum Aggregation {
Metric(MetricAggregation),
}
impl Aggregation {
    /// Records the fast field name(s) used by this aggregation into
    /// `fast_field_names`, delegating to the bucket or metric variant.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        match self {
            Aggregation::Bucket(bucket_agg) => {
                bucket_agg.get_fast_field_names(fast_field_names)
            }
            Aggregation::Metric(metric_agg) => {
                metric_agg.get_fast_field_names(fast_field_names)
            }
        }
    }
}
/// BucketAggregations create buckets of documents. Each bucket is associated with a rule which
/// determines whether or not a document falls into it. In other words, the buckets
/// effectively define document sets. Buckets are not necessarily disjunct, therefore a document can
@@ -92,6 +110,13 @@ pub struct BucketAggregation {
pub sub_aggregation: Aggregations,
}
impl BucketAggregation {
    /// Records the fast field used by this bucket aggregation, plus every
    /// fast field used by its sub-aggregations.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        self.bucket_agg.get_fast_field_names(fast_field_names);
        for name in get_fast_field_names(&self.sub_aggregation) {
            fast_field_names.insert(name);
        }
    }
}
/// The bucket aggregation types.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum BucketAggregationType {
@@ -100,6 +125,14 @@ pub enum BucketAggregationType {
Range(RangeAggregation),
}
impl BucketAggregationType {
    /// Records the fast field targeted by this bucket aggregation type.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        match self {
            BucketAggregationType::Range(range) => {
                // The returned bool (newly inserted or not) is irrelevant here.
                fast_field_names.insert(range.field.to_string());
            }
        }
    }
}
/// The aggregations in this family compute metrics based on values extracted
/// from the documents that are being aggregated. Values are extracted from the fast field of
/// the document.
@@ -117,6 +150,15 @@ pub enum MetricAggregation {
Stats(StatsAggregation),
}
impl MetricAggregation {
    /// Records the fast field this metric aggregation reads its values from.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        let field = match self {
            MetricAggregation::Average(avg) => &avg.field,
            MetricAggregation::Stats(stats) => &stats.field,
        };
        fast_field_names.insert(field.to_string());
    }
}
#[cfg(test)]
mod tests {
use super::*;
@@ -167,4 +209,62 @@ mod tests {
let agg_req2: String = serde_json::to_string_pretty(&agg_req1).unwrap();
assert_eq!(agg_req2, elasticsearch_compatible_json_req);
}
#[test]
fn test_get_fast_field_names() {
    // Sub-aggregations: a range bucket on "score2" and an average metric on
    // "field123".
    let sub_aggs: Aggregations = vec![
        (
            "range".to_string(),
            Aggregation::Bucket(BucketAggregation {
                bucket_agg: BucketAggregationType::Range(RangeAggregation {
                    field: "score2".to_string(),
                    ranges: vec![
                        (f64::MIN..3f64).into(),
                        (3f64..7f64).into(),
                        (7f64..20f64).into(),
                        (20f64..f64::MAX).into(),
                    ],
                }),
                sub_aggregation: Default::default(),
            }),
        ),
        (
            "metric".to_string(),
            Aggregation::Metric(MetricAggregation::Average(
                AverageAggregation::from_field_name("field123".to_string()),
            )),
        ),
    ]
    .into_iter()
    .collect();

    // Top level: a range bucket on "score" carrying the sub-aggregations.
    let agg_req: Aggregations = std::iter::once((
        "range".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Range(RangeAggregation {
                field: "score".to_string(),
                ranges: vec![
                    (f64::MIN..3f64).into(),
                    (3f64..7f64).into(),
                    (7f64..20f64).into(),
                    (20f64..f64::MAX).into(),
                ],
            }),
            sub_aggregation: sub_aggs,
        }),
    ))
    .collect();

    // All three fields, from every nesting level, must be collected.
    let expected: HashSet<String> = ["score", "score2", "field123"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    assert_eq!(get_fast_field_names(&agg_req), expected);
}
}

View File

@@ -4,8 +4,8 @@ use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAgg
use super::bucket::RangeAggregation;
use super::metric::{AverageAggregation, StatsAggregation};
use super::VecWithNames;
use crate::fastfield::DynamicFastFieldReader;
use crate::schema::Type;
use crate::fastfield::{type_and_cardinality, DynamicFastFieldReader, FastType};
use crate::schema::{Cardinality, Type};
use crate::{SegmentReader, TantivyError};
#[derive(Clone, Default)]
@@ -38,7 +38,7 @@ pub struct BucketAggregationWithAccessor {
}
impl BucketAggregationWithAccessor {
fn from_bucket(
fn try_from_bucket(
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
@@ -53,7 +53,7 @@ impl BucketAggregationWithAccessor {
Ok(BucketAggregationWithAccessor {
accessor,
field_type,
sub_aggregation: get_aggregations_with_accessor(&sub_aggregation, reader)?,
sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
bucket_agg: bucket.clone(),
})
}
@@ -68,7 +68,7 @@ pub struct MetricAggregationWithAccessor {
}
impl MetricAggregationWithAccessor {
fn from_metric(
fn try_from_metric(
metric: &MetricAggregation,
reader: &SegmentReader,
) -> crate::Result<MetricAggregationWithAccessor> {
@@ -87,7 +87,7 @@ impl MetricAggregationWithAccessor {
}
}
pub(crate) fn get_aggregations_with_accessor(
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
) -> crate::Result<AggregationsWithAccessor> {
@@ -97,7 +97,7 @@ pub(crate) fn get_aggregations_with_accessor(
match agg {
Aggregation::Bucket(bucket) => buckets.push((
key.to_string(),
BucketAggregationWithAccessor::from_bucket(
BucketAggregationWithAccessor::try_from_bucket(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
@@ -105,7 +105,7 @@ pub(crate) fn get_aggregations_with_accessor(
)),
Aggregation::Metric(metric) => metrics.push((
key.to_string(),
MetricAggregationWithAccessor::from_metric(metric, reader)?,
MetricAggregationWithAccessor::try_from_metric(metric, reader)?,
)),
}
}
@@ -124,15 +124,21 @@ fn get_ff_reader_and_validate(
.get_field(field_name)
.ok_or_else(|| TantivyError::FieldNotFound(field_name.to_string()))?;
let field_type = reader.schema().get_field_entry(field).field_type();
if field_type.value_type() != Type::I64
&& field_type.value_type() != Type::U64
&& field_type.value_type() != Type::F64
{
if let Some((ff_type, cardinality)) = type_and_cardinality(field_type) {
if cardinality == Cardinality::MultiValues || ff_type == FastType::Date {
return Err(TantivyError::InvalidArgument(format!(
"Invalid field type in aggregation {:?}, only Cardinality::SingleValue supported",
field_type.value_type()
)));
}
} else {
return Err(TantivyError::InvalidArgument(format!(
"Invalid field type in aggregation {:?}, only f64, u64, i64 is supported",
"Only single value fast fields of type f64, u64, i64 are supported, but got {:?} ",
field_type.value_type()
)));
}
};
let ff_fields = reader.fast_fields();
ff_fields
.u64_lenient(field)

View File

@@ -1,6 +1,5 @@
use std::ops::Range;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
@@ -116,7 +115,7 @@ impl SegmentRangeCollector {
IntermediateBucketResult::Range(buckets)
}
pub(crate) fn from_req(
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
field_type: Type,
@@ -140,7 +139,7 @@ impl SegmentRangeCollector {
let sub_aggregation = if sub_aggregation.is_empty() {
None
} else {
Some(SegmentAggregationResultsCollector::from_req(
Some(SegmentAggregationResultsCollector::from_req_and_validate(
sub_aggregation,
)?)
};
@@ -239,15 +238,24 @@ impl SegmentRangeCollector {
/// fast field.
/// The alternative would be that every value read would be converted to the f64 range, but that is
/// more computational expensive when many documents are hit.
fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> Range<u64> {
range
.from
.map(|from| f64_to_fastfield_u64(from, field_type))
.unwrap_or(u64::MIN)
..range
.to
.map(|to| f64_to_fastfield_u64(to, field_type))
.unwrap_or(u64::MAX)
/// Converts the bounds of a `RangeAggregationRange` from the f64 request
/// domain into the u64 fast field domain.
///
/// Open bounds map to `u64::MIN` / `u64::MAX` respectively.
///
/// # Errors
/// Returns `TantivyError::InvalidArgument` when `field_type` has no u64
/// fast field representation.
fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Result<Range<u64>> {
    // `ok_or_else` instead of `ok_or`: the error string is only allocated on
    // the failure path, not on every successful conversion.
    let start = match range.from {
        Some(from) => f64_to_fastfield_u64(from, field_type)
            .ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?,
        None => u64::MIN,
    };
    let end = match range.to {
        Some(to) => f64_to_fastfield_u64(to, field_type)
            .ok_or_else(|| TantivyError::InvalidArgument("invalid field type".to_string()))?,
        None => u64::MAX,
    };
    Ok(start..end)
}
/// Extends the provided buckets to contain the whole value range, by inserting buckets at the
@@ -259,7 +267,7 @@ fn extend_validate_ranges(
let mut converted_buckets = buckets
.iter()
.map(|range| to_u64_range(range, field_type))
.collect_vec();
.collect::<crate::Result<Vec<_>>>()?;
converted_buckets.sort_by_key(|bucket| bucket.start);
if converted_buckets[0].start != u64::MIN {
@@ -335,7 +343,7 @@ mod tests {
ranges,
};
SegmentRangeCollector::from_req(&req, &Default::default(), field_type).unwrap()
SegmentRangeCollector::from_req_and_validate(&req, &Default::default(), field_type).unwrap()
}
#[test]
@@ -499,6 +507,7 @@ mod tests {
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;

View File

@@ -3,9 +3,9 @@ use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::agg_req_with_accessor::get_aggregations_with_accessor;
use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::TantivyError;
use crate::{SegmentReader, TantivyError};
/// Collector for aggregations.
///
@@ -50,12 +50,7 @@ impl Collector for DistributedAggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let aggs_with_accessor = get_aggregations_with_accessor(&self.agg, reader)?;
let result = SegmentAggregationResultsCollector::from_req(&aggs_with_accessor)?;
Ok(AggregationSegmentCollector {
aggs: aggs_with_accessor,
result,
})
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader)
}
fn requires_scoring(&self) -> bool {
@@ -80,8 +75,9 @@ impl Collector for AggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let aggs_with_accessor = get_aggregations_with_accessor(&self.agg, reader)?;
let result = SegmentAggregationResultsCollector::from_req(&aggs_with_accessor)?;
let aggs_with_accessor = get_aggs_with_accessor_and_validate(&self.agg, reader)?;
let result =
SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?;
Ok(AggregationSegmentCollector {
aggs: aggs_with_accessor,
result,
@@ -115,11 +111,29 @@ fn merge_fruits(
}
}
/// AggregationSegmentCollector does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs: AggregationsWithAccessor,
result: SegmentAggregationResultsCollector,
}
impl AggregationSegmentCollector {
    /// Creates an AggregationSegmentCollector from an [Aggregations] request and a segment reader.
    /// Also includes validation, e.g. checking field types and existence.
    ///
    /// # Errors
    /// Fails if a requested field does not exist in the segment or is not a
    /// supported single-value numeric fast field.
    pub fn from_agg_req_and_reader(
        agg: &Aggregations,
        reader: &SegmentReader,
    ) -> crate::Result<Self> {
        // `agg` is already `&Aggregations`; passing it directly avoids the
        // needless extra borrow (`&agg` would be `&&Aggregations`).
        let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?;
        let result =
            SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?;
        Ok(AggregationSegmentCollector {
            aggs: aggs_with_accessor,
            result,
        })
    }
}
impl SegmentCollector for AggregationSegmentCollector {
type Fruit = IntermediateAggregationResults;

View File

@@ -95,7 +95,7 @@ impl IntermediateStats {
self.max = self.max.max(other.max);
}
/// compute final result
/// Computes the final result.
pub fn finalize(&self) -> Stats {
Stats {
count: self.count,

View File

@@ -28,7 +28,9 @@
//! let agg_res = searcher.search(&term_query, &collector).unwrap_err();
//! let json_response_string: String = &serde_json::to_string(&agg_res)?;
//! ```
//! # Limitations
//!
//! Currently aggregations work only on single value fast fields of type u64, f64 and i64.
//!
//! # Example
//! Compute the average metric, by building [agg_req::Aggregations], which is built from an (String,
@@ -150,7 +152,9 @@ mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
pub use collector::{AggregationCollector, DistributedAggregationCollector};
pub use collector::{
AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector,
};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@@ -250,13 +254,18 @@ impl Serialize for Key {
}
}
/// Invert of to_fastfield_u64
/// Invert of to_fastfield_u64. Used to convert to f64 for metrics.
///
/// # Panics
/// Only u64, f64 and i64 are supported
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &Type) -> f64 {
match field_type {
Type::U64 => val as f64,
Type::I64 => i64::from_u64(val) as f64,
Type::F64 => f64::from_u64(val),
Type::Date | Type::Str | Type::Facet | Type::Bytes | Type::Json => unimplemented!(),
_ => {
panic!("unexpected type {:?}. This should not happen", field_type)
}
}
}
@@ -270,12 +279,12 @@ pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &Type) -> f64 {
/// A f64 value of e.g. 2.0 needs to be converted using the same monotonic
/// conversion function, so that the value matches the u64 value stored in the fast
/// field.
pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &Type) -> u64 {
pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &Type) -> Option<u64> {
match field_type {
Type::U64 => val as u64,
Type::I64 => (val as i64).to_u64(),
Type::F64 => val.to_u64(),
Type::Date | Type::Str | Type::Facet | Type::Bytes | Type::Json => unimplemented!(),
Type::U64 => Some(val as u64),
Type::I64 => Some((val as i64).to_u64()),
Type::F64 => Some(val.to_u64()),
_ => None,
}
}
@@ -293,7 +302,7 @@ mod tests {
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::segment_agg_result::DOC_BLOCK_SIZE;
use crate::aggregation::DistributedAggregationCollector;
use crate::query::TermQuery;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{Cardinality, IndexRecordOption, Schema, TextFieldIndexing};
use crate::{Index, Term};
@@ -467,6 +476,11 @@ mod tests {
crate::schema::NumericOptions::default().set_fast(Cardinality::SingleValue);
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let multivalue =
crate::schema::NumericOptions::default().set_fast(Cardinality::MultiValues);
let scores_field_i64 = schema_builder.add_i64_field("scores_i64", multivalue);
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_in_ram(schema_builder.build());
{
@@ -477,12 +491,16 @@ mod tests {
score_field => 1u64,
score_field_f64 => 1f64,
score_field_i64 => 1i64,
scores_field_i64 => 1i64,
scores_field_i64 => 2i64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
score_field => 3u64,
score_field_f64 => 3f64,
score_field_i64 => 3i64,
scores_field_i64 => 5i64,
scores_field_i64 => 5i64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
@@ -852,31 +870,42 @@ mod tests {
let index = get_test_index_2_segments(false)?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let avg_on_field = |field_name: &str| {
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name(field_name.to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("text".to_string()),
)),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1);
let collector = AggregationCollector::from_aggs(agg_req_1);
let searcher = reader.searcher();
let agg_res = searcher.search(&term_query, &collector).unwrap_err();
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap_err();
agg_res
};
let agg_res = avg_on_field("text");
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("Invalid field type in aggregation Str, only f64, u64, i64 is supported")"#
r#"InvalidArgument("Only single value fast fields of type f64, u64, i64 are supported, but got Str ")"#
);
let agg_res = avg_on_field("not_exist_field");
assert_eq!(
format!("{:?}", agg_res),
r#"FieldNotFound("not_exist_field")"#
);
let agg_res = avg_on_field("scores_i64");
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("Invalid field type in aggregation I64, only Cardinality::SingleValue supported")"#
);
Ok(())
}

View File

@@ -5,8 +5,6 @@
use std::fmt::Debug;
use itertools::Itertools;
use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
@@ -42,22 +40,27 @@ impl Debug for SegmentAggregationResultsCollector {
}
impl SegmentAggregationResultsCollector {
pub(crate) fn from_req(req: &AggregationsWithAccessor) -> crate::Result<Self> {
pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result<Self> {
let buckets = req
.buckets
.entries()
.map(|(key, req)| {
Ok((
key.to_string(),
SegmentBucketResultCollector::from_req(req)?,
SegmentBucketResultCollector::from_req_and_validate(req)?,
))
})
.collect::<crate::Result<_>>()?;
let metrics = req
.metrics
.entries()
.map(|(key, req)| (key.to_string(), SegmentMetricResultCollector::from_req(req)))
.collect_vec();
.map(|(key, req)| {
Ok((
key.to_string(),
SegmentMetricResultCollector::from_req_and_validate(req)?,
))
})
.collect::<crate::Result<_>>()?;
Ok(SegmentAggregationResultsCollector {
metrics: VecWithNames::from_entries(metrics),
buckets: VecWithNames::from_entries(buckets),
@@ -115,15 +118,17 @@ pub(crate) enum SegmentMetricResultCollector {
}
impl SegmentMetricResultCollector {
pub fn from_req(req: &MetricAggregationWithAccessor) -> Self {
pub fn from_req_and_validate(req: &MetricAggregationWithAccessor) -> crate::Result<Self> {
match &req.metric {
MetricAggregation::Average(AverageAggregation { field: _ }) => {
SegmentMetricResultCollector::Average(SegmentAverageCollector::from_req(
req.field_type,
Ok(SegmentMetricResultCollector::Average(
SegmentAverageCollector::from_req(req.field_type),
))
}
MetricAggregation::Stats(StatsAggregation { field: _ }) => {
SegmentMetricResultCollector::Stats(SegmentStatsCollector::from_req(req.field_type))
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type),
))
}
}
}
@@ -149,11 +154,15 @@ pub(crate) enum SegmentBucketResultCollector {
}
impl SegmentBucketResultCollector {
pub fn from_req(req: &BucketAggregationWithAccessor) -> crate::Result<Self> {
pub fn from_req_and_validate(req: &BucketAggregationWithAccessor) -> crate::Result<Self> {
match &req.bucket_agg {
BucketAggregationType::Range(range_req) => Ok(Self::Range(
SegmentRangeCollector::from_req(range_req, &req.sub_aggregation, req.field_type)?,
)),
BucketAggregationType::Range(range_req) => {
Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
req.field_type,
)?))
}
}
}

View File

@@ -28,6 +28,7 @@ pub use self::facet_reader::FacetReader;
pub use self::multivalued::{MultiValuedFastFieldReader, MultiValuedFastFieldWriter};
pub use self::reader::{DynamicFastFieldReader, FastFieldReader};
pub use self::readers::FastFieldReaders;
pub(crate) use self::readers::{type_and_cardinality, FastType};
pub use self::serializer::{CompositeFastFieldSerializer, FastFieldDataAccess, FastFieldStats};
pub use self::writer::{FastFieldsWriter, IntFastFieldWriter};
use crate::chrono::{NaiveDateTime, Utc};

View File

@@ -17,14 +17,14 @@ pub struct FastFieldReaders {
fast_fields_composite: CompositeFile,
}
#[derive(Eq, PartialEq, Debug)]
enum FastType {
pub(crate) enum FastType {
I64,
U64,
F64,
Date,
}
fn type_and_cardinality(field_type: &FieldType) -> Option<(FastType, Cardinality)> {
pub(crate) fn type_and_cardinality(field_type: &FieldType) -> Option<(FastType, Cardinality)> {
match field_type {
FieldType::U64(options) => options
.get_fastfield_cardinality()

View File

@@ -74,6 +74,15 @@ impl NumericOptions {
self.fieldnorms && self.indexed
}
/// Returns true iff the value is a fast field and multivalue.
pub fn is_multivalue_fast(&self) -> bool {
    // Equivalent to the if-let/else form: true only when a cardinality is
    // set and it is `MultiValues`.
    matches!(self.fast, Some(Cardinality::MultiValues))
}
/// Returns true iff the value is a fast field.
pub fn is_fast(&self) -> bool {
self.fast.is_some()