add segment_size, add get term dict fields, add tests

This commit is contained in:
Pascal Seitz
2022-04-12 14:45:10 +08:00
parent 24432bf523
commit 46724b4a05
5 changed files with 163 additions and 37 deletions

View File

@@ -114,6 +114,15 @@ impl BucketAggregationInternal {
}
}
/// Extract all fields, where the term directory is used in the tree.
pub fn get_term_dict_field_names(aggs: &Aggregations) -> HashSet<String> {
let mut term_dict_field_names = Default::default();
for el in aggs.values() {
el.get_term_dict_field_names(&mut term_dict_field_names)
}
term_dict_field_names
}
/// Extract all fast field names used in the tree.
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
let mut fast_field_names = Default::default();
@@ -136,6 +145,12 @@ pub enum Aggregation {
}
impl Aggregation {
fn get_term_dict_field_names(&self, term_field_names: &mut HashSet<String>) {
if let Aggregation::Bucket(bucket) = self {
bucket.get_term_dict_field_names(term_field_names)
}
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
match self {
Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names),
@@ -168,6 +183,11 @@ pub struct BucketAggregation {
}
impl BucketAggregation {
fn get_term_dict_field_names(&self, term_dict_field_names: &mut HashSet<String>) {
if let BucketAggregationType::Terms(terms) = &self.bucket_agg {
term_dict_field_names.insert(terms.field.to_string());
}
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
self.bucket_agg.get_fast_field_names(fast_field_names);
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));

View File

@@ -39,7 +39,7 @@ use crate::{DocId, TantivyError};
/// # Request JSON Format
/// ```json
/// {
/// "range": {
/// "my_ranges": {
/// "field": "score",
/// "ranges": [
/// { "to": 3.0 },

View File

@@ -53,7 +53,7 @@ use crate::DocId;
/// ```json
/// {
/// "genres": {
/// "field": "genre",
/// "terms":{ "field": "genre" }
/// }
/// }
/// ```
@@ -65,10 +65,21 @@ pub struct TermsAggregation {
/// Larger values for size are more expensive.
pub size: Option<u32>,
/// The get more accurate results, we fetch more than `size` from each segment.
/// By default we fetch `shard_size` terms, which defaults to size * 1.5 + 10.
/// Unused by tantivy.
///
/// Since tantivy doesn't know shards, this parameter is merely there to be used by consumers
/// of tantivy. shard_size is the number of terms returned by each shard.
/// The default value in elasticsearch is size * 1.5 + 10.
///
/// Should never be smaller than size.
pub shard_size: Option<u32>,
/// The get more accurate results, we fetch more than `size` from each segment.
/// TODO document default
///
/// Increasing this value is will increase the cost for more accuracy.
pub segment_size: Option<u32>,
/// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will
/// include doc_count_error_upper_bound, which is an upper bound to the error on the
/// doc_count returned by each shard. Its the sum of the size of the largest bucket on
@@ -76,8 +87,8 @@ pub struct TermsAggregation {
#[serde(default = "default_show_term_doc_count_error")]
pub show_term_doc_count_error: bool,
/// Filter all terms than are lower `min_doc_count`.
pub min_doc_count: Option<usize>,
/// Filter all terms than are lower `min_doc_count`. Defaults to 1.
pub min_doc_count: Option<u64>,
}
impl Default for TermsAggregation {
fn default() -> Self {
@@ -87,6 +98,7 @@ impl Default for TermsAggregation {
shard_size: Default::default(),
show_term_doc_count_error: true,
min_doc_count: Default::default(),
segment_size: Default::default(),
}
}
}
@@ -104,44 +116,42 @@ pub(crate) struct TermsAggregationInternal {
/// Larger values for size are more expensive.
pub size: u32,
/// The get more accurate results, we fetch more than `size` from each segment.
/// By default we fetch `shard_size` terms, which defaults to size * 1.5 + 10.
///
/// Cannot be smaller than size. In that case it will be set automatically to size.
pub shard_size: u32,
/// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will
/// include doc_count_error_upper_bound, which is an upper bound to the error on the
/// doc_count returned by each shard. Its the sum of the size of the largest bucket on
/// each segment that didnt fit into `shard_size`.
pub show_term_doc_count_error: bool,
/// Filter all terms than are lower `min_doc_count`.
pub min_doc_count: Option<usize>,
/// The get more accurate results, we fetch more than `size` from each segment.
///
/// Increasing this value is will increase the cost for more accuracy.
pub segment_size: u32,
/// Filter all terms than are lower `min_doc_count`. Defaults to 1.
pub min_doc_count: u64,
}
impl TermsAggregationInternal {
pub(crate) fn from_req(req: &TermsAggregation) -> Self {
let size = req.size.unwrap_or(10);
let mut shard_size = req
.shard_size
let mut segment_size = req
.segment_size
.unwrap_or((size as f32 * 1.5_f32) as u32 + 10);
shard_size = shard_size.max(size);
segment_size = segment_size.max(size);
TermsAggregationInternal {
field: req.field.to_string(),
size,
shard_size,
segment_size,
show_term_doc_count_error: req.show_term_doc_count_error,
min_doc_count: req.min_doc_count,
min_doc_count: req.min_doc_count.unwrap_or(1),
}
}
}
const TERM_BUCKET_SIZE: usize = 100;
#[derive(Clone, Debug, PartialEq)]
/// Chunks the term_id value range in TERM_BUCKET_SIZE blocks.
/// Container to store term_ids and their buckets.
struct TermBuckets {
pub(crate) entries: FnvHashMap<u32, TermBucketEntry>,
blueprint: Option<SegmentAggregationResultsCollector>,
@@ -189,10 +199,9 @@ impl TermBucketEntry {
impl TermBuckets {
pub(crate) fn from_req_and_validate(
sub_aggregation: &AggregationsWithAccessor,
max_term_id: usize,
_max_term_id: usize,
) -> crate::Result<Self> {
let has_sub_aggregations = sub_aggregation.is_empty();
let _num_chunks = (max_term_id / TERM_BUCKET_SIZE) + 1;
let blueprint = if has_sub_aggregations {
let sub_aggregation =
@@ -259,7 +268,7 @@ impl SegmentTermCollector {
let term_buckets =
TermBuckets::from_req_and_validate(sub_aggregations, max_term_id as usize)?;
let has_sub_aggregations = sub_aggregations.is_empty();
let has_sub_aggregations = !sub_aggregations.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation =
SegmentAggregationResultsCollector::from_req_and_validate(sub_aggregations)?;
@@ -283,7 +292,7 @@ impl SegmentTermCollector {
let mut entries: Vec<_> = self.term_buckets.entries.into_iter().collect();
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, self.req.shard_size as usize);
cut_off_buckets(&mut entries, self.req.segment_size as usize);
let inverted_index = agg_with_accessor
.inverted_index
@@ -403,7 +412,8 @@ pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
mod tests {
use super::*;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
get_term_dict_field_names, Aggregation, Aggregations, BucketAggregation,
BucketAggregationType,
};
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
@@ -476,6 +486,37 @@ mod tests {
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 1);
// test min_doc_count
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
shard_size: Some(2),
min_doc_count: Some(3),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let res = exec_request(agg_req.clone(), &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(
res["my_texts"]["buckets"][1]["key"],
serde_json::Value::Null
);
assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); // TODO sum_other_doc_count with min_doc_count
assert_eq!(
get_term_dict_field_names(&agg_req),
vec!["string_id".to_string(),].into_iter().collect()
);
Ok(())
}
@@ -496,7 +537,7 @@ mod tests {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
shard_size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
@@ -599,4 +640,13 @@ mod bench {
fn bench_fnv_buckets_1_000_000_of_50(b: &mut test::Bencher) {
bench_term_hashmap(b, 1_000_000u64, 50u64)
}
#[bench]
fn bench_term_buckets_1_000_000_of_1_000_000(b: &mut test::Bencher) {
bench_term_buckets(b, 1_000_000u64, 1_000_000u64)
}
#[bench]
fn bench_fnv_buckets_1_000_000_of_1_000_000(b: &mut test::Bencher) {
bench_term_hashmap(b, 1_000_000u64, 1_000_000u64)
}
}

View File

@@ -251,6 +251,7 @@ impl IntermediateTermBucketResult {
let mut buckets: Vec<BucketEntry> = self
.entries
.into_iter()
.filter(|bucket| bucket.1.doc_count >= req.min_doc_count)
.map(|(key, entry)| BucketEntry {
key: Key::Str(key),
doc_count: entry.doc_count,

View File

@@ -456,15 +456,13 @@ mod tests {
merge_segments: bool,
use_distributed_collector: bool,
) -> crate::Result<()> {
let index = get_test_index_with_num_docs(merge_segments, 80)?;
let mut values_and_terms = (0..80)
.map(|val| vec![(val as f64, "terma".to_string())])
.collect::<Vec<_>>();
values_and_terms.last_mut().unwrap()[0].1 = "termb".to_string();
let index = get_test_index_from_values_and_terms(merge_segments, &values_and_terms)?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
@@ -509,6 +507,19 @@ mod tests {
}
}
}
},
"term_agg_test":{
"terms": {
"field": "string_id"
},
"aggs": {
"bucketsL2": {
"histogram": {
"field": "score",
"interval": 70.0
}
}
}
}
});
@@ -521,17 +532,18 @@ mod tests {
let searcher = reader.searcher();
AggregationResults::from_intermediate_and_req(
searcher.search(&term_query, &collector).unwrap(),
searcher.search(&AllQuery, &collector).unwrap(),
agg_req,
)
} else {
let collector = AggregationCollector::from_aggs(agg_req);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
searcher.search(&AllQuery, &collector).unwrap()
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
// println!("{}", serde_json::to_string_pretty(&res).unwrap());
assert_eq!(res["bucketsL1"]["buckets"][0]["doc_count"], 3);
assert_eq!(
@@ -558,6 +570,46 @@ mod tests {
);
assert_eq!(res["bucketsL1"]["buckets"][2]["doc_count"], 80 - 70);
assert_eq!(
res["term_agg_test"],
json!(
{
"buckets": [
{
"bucketsL2": {
"buckets": [
{
"doc_count": 70,
"key": 0.0
},
{
"doc_count": 9,
"key": 70.0
}
]
},
"doc_count": 79,
"key": "terma"
},
{
"bucketsL2": {
"buckets": [
{
"doc_count": 1,
"key": 70.0
}
]
},
"doc_count": 1,
"key": "termb"
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
)
);
Ok(())
}
@@ -1085,6 +1137,9 @@ mod tests {
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = vec!["INFO", "ERROR", "WARN", "DEBUG"];
let many_terms_data = (0..150_000)
.map(|num| format!("author{}", num))
.collect::<Vec<_>>();
{
let mut rng = thread_rng();
let mut index_writer = index.writer_for_tests()?;
@@ -1093,7 +1148,7 @@ mod tests {
let val: f64 = rng.gen_range(0.0..1_000_000.0);
index_writer.add_document(doc!(
text_field => "cool",
text_field_many_terms => val.to_string(),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
score_field => val as u64,
score_field_f64 => val as f64,