Mirror of https://github.com/quickwit-oss/tantivy.git, synced 2025-12-26 20:19:57 +00:00
feat: add support for u64,i64,f64 fields in term aggregation (#1883)
* feat: add support for u64,i64,f64 fields in term aggregation
* hash enum values
* fix build
* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
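In short: a terms aggregation can now target a numeric fast field directly, and the bucket keys come back as numbers rather than strings. A minimal sketch of such a request, assuming the request types used in the tests below and an index with a `u64` fast field named "score" (field and aggregation names are illustrative, and module paths may differ between tantivy versions):

    use tantivy::aggregation::agg_req::{
        Aggregation, Aggregations, BucketAggregation, BucketAggregationType, TermsAggregation,
    };

    // Bucket documents by the numeric fast field "score"; no sub-aggregations.
    let agg_req: Aggregations = vec![(
        "scores".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Terms(TermsAggregation {
                field: "score".to_string(),
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        }),
    )]
    .into_iter()
    .collect();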
@@ -1,6 +1,6 @@
 use std::fmt::Debug;
 
-use columnar::Cardinality;
+use columnar::{Cardinality, ColumnType};
 use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
 
@@ -15,7 +15,7 @@ use crate::aggregation::intermediate_agg_result::{
 use crate::aggregation::segment_agg_result::{
     build_segment_agg_collector, SegmentAggregationCollector,
 };
-use crate::aggregation::VecWithNames;
+use crate::aggregation::{f64_from_fastfield_u64, Key, VecWithNames};
 use crate::error::DataCorruption;
 use crate::TantivyError;
 
@@ -25,6 +25,10 @@ use crate::TantivyError;
 /// If the text is untokenized and single value, that means one term per document and therefore it
 /// is in fact doc count.
 ///
+/// ## Prerequisite
+/// Term aggregations work only on [fast fields](`crate::fastfield`) of type `u64`, `f64`, `i64` and
+/// text.
+///
 /// ### Terminology
 /// Shard parameters are supposed to be equivalent to elasticsearch shard parameter.
 /// Since they are
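For context on the prerequisite above: the aggregated field has to be declared as a fast field in the schema. A minimal sketch using tantivy's schema builder (field names are illustrative; the exact flags for fast text fields vary by version, so only the numeric fields relevant to this change are shown):

    use tantivy::schema::{Schema, FAST};

    let mut schema_builder = Schema::builder();
    // Numeric fast fields: with this change, all three are valid targets for a terms aggregation.
    schema_builder.add_u64_field("score", FAST);
    schema_builder.add_f64_field("score_f64", FAST);
    schema_builder.add_i64_field("score_i64", FAST);
    let schema = schema_builder.build();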
@@ -199,9 +203,9 @@ impl TermsAggregationInternal {
 }
 
 #[derive(Clone, Debug, Default)]
-/// Container to store term_ids and their buckets.
+/// Container to store term_ids/or u64 values and their buckets.
 struct TermBuckets {
-    pub(crate) entries: FxHashMap<u32, TermBucketEntry>,
+    pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
 }
 
 #[derive(Clone, Default)]
@@ -262,6 +266,7 @@ pub struct SegmentTermCollector {
     term_buckets: TermBuckets,
     req: TermsAggregationInternal,
     blueprint: Option<Box<dyn SegmentAggregationCollector>>,
+    field_type: ColumnType,
     accessor_idx: usize,
 }
 
@@ -310,7 +315,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -323,7 +328,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
             let entry = self
                 .term_buckets
                 .entries
-                .entry(term_id as u32)
+                .entry(term_id)
                 .or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
@@ -348,6 +353,7 @@ impl SegmentTermCollector {
     pub(crate) fn from_req_and_validate(
         req: &TermsAggregation,
         sub_aggregations: &AggregationsWithAccessor,
+        field_type: ColumnType,
         accessor_idx: usize,
     ) -> crate::Result<Self> {
         let term_buckets = TermBuckets::default();
@@ -378,6 +384,7 @@ impl SegmentTermCollector {
             req: TermsAggregationInternal::from_req(req),
             term_buckets,
             blueprint,
+            field_type,
             accessor_idx,
         })
     }
@@ -386,7 +393,7 @@ impl SegmentTermCollector {
         self,
         agg_with_accessor: &BucketAggregationWithAccessor,
     ) -> crate::Result<IntermediateBucketResult> {
-        let mut entries: Vec<(u32, TermBucketEntry)> =
+        let mut entries: Vec<(u64, TermBucketEntry)> =
             self.term_buckets.entries.into_iter().collect();
 
         let order_by_sub_aggregation =
@@ -423,41 +430,52 @@ impl SegmentTermCollector {
             cut_off_buckets(&mut entries, self.req.segment_size as usize)
         };
 
-        let mut dict: FxHashMap<String, IntermediateTermBucketEntry> = Default::default();
+        let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
         dict.reserve(entries.len());
-        let str_column = agg_with_accessor
-            .str_dict_column
-            .as_ref()
-            .expect("Missing str column"); //< TODO Fixme
-
-        let mut buffer = String::new();
-        for (term_id, entry) in entries {
-            if !str_column.ord_to_str(term_id as u64, &mut buffer)? {
-                return Err(TantivyError::InternalError(format!(
-                    "Couldn't find term_id {} in dict",
-                    term_id
-                )));
-            }
-            dict.insert(
-                buffer.to_string(),
-                entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
-            );
-        }
-        if self.req.min_doc_count == 0 {
-            // TODO: Handle rev streaming for descending sorting by keys
-            let mut stream = str_column.dictionary().stream()?;
-            while let Some((key, _ord)) = stream.next() {
-                if dict.len() >= self.req.segment_size as usize {
-                    break;
-                }
-
-                let key = std::str::from_utf8(key)
-                    .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?;
-                if !dict.contains_key(key) {
-                    dict.insert(key.to_owned(), Default::default());
-                }
-            }
-        }
+        if self.field_type == ColumnType::Str {
+            let term_dict = agg_with_accessor
+                .str_dict_column
+                .as_ref()
+                .expect("internal error: term dictionary not found for term aggregation");
+
+            let mut buffer = String::new();
+            for (term_id, entry) in entries {
+                if !term_dict.ord_to_str(term_id, &mut buffer)? {
+                    return Err(TantivyError::InternalError(format!(
+                        "Couldn't find term_id {} in dict",
+                        term_id
+                    )));
+                }
+                dict.insert(
+                    Key::Str(buffer.to_string()),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+            if self.req.min_doc_count == 0 {
+                // TODO: Handle rev streaming for descending sorting by keys
+                let mut stream = term_dict.dictionary().stream()?;
+                while let Some((key, _ord)) = stream.next() {
+                    if dict.len() >= self.req.segment_size as usize {
+                        break;
+                    }
+
+                    let key = Key::Str(
+                        std::str::from_utf8(key)
+                            .map_err(|utf8_err| DataCorruption::comment_only(utf8_err.to_string()))?
+                            .to_string(),
+                    );
+                    dict.entry(key).or_default();
+                }
+            }
+        } else {
+            for (val, entry) in entries {
+                let val = f64_from_fastfield_u64(val, &self.field_type);
+                dict.insert(
+                    Key::F64(val),
+                    entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
+                );
+            }
+        };
 
         Ok(IntermediateBucketResult::Terms(
             IntermediateTermBucketResult {
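In the `else` branch above, the new numeric path kicks in: fast-field columns store every numeric type as `u64` in an order-preserving encoding, and `f64_from_fastfield_u64` converts the raw value back into its logical value based on the column type before it becomes a `Key::F64` bucket key. A conceptual sketch of that kind of round trip, not tantivy's exact implementation (helper names are illustrative):

    // i64 <-> u64: flipping the sign bit preserves ordering under unsigned comparison.
    fn i64_to_u64_sketch(val: i64) -> u64 {
        (val as u64) ^ (1 << 63)
    }

    fn u64_to_i64_sketch(val: u64) -> i64 {
        (val ^ (1 << 63)) as i64
    }

    // f64 <-> u64: a monotonic bit trick so that larger floats map to larger integers.
    fn f64_to_u64_sketch(val: f64) -> u64 {
        let bits = val.to_bits();
        if val.is_sign_positive() {
            bits ^ (1 << 63)
        } else {
            !bits
        }
    }

    fn u64_to_f64_sketch(val: u64) -> f64 {
        if val & (1 << 63) != 0 {
            f64::from_bits(val ^ (1 << 63))
        } else {
            f64::from_bits(!val)
        }
    }

    // Round-tripping preserves both the value and the ordering, which is what lets the
    // collector bucket on raw u64 column values and only convert at the very end.
    assert_eq!(u64_to_i64_sketch(i64_to_u64_sketch(-7)), -7);
    assert_eq!(u64_to_f64_sketch(f64_to_u64_sketch(4.5)), 4.5);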
@@ -477,6 +495,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
         self.1.doc_count
     }
 }
+impl GetDocCount for (u64, TermBucketEntry) {
+    fn doc_count(&self) -> u64 {
+        self.1.doc_count
+    }
+}
 impl GetDocCount for (String, IntermediateTermBucketEntry) {
     fn doc_count(&self) -> u64 {
         self.1.doc_count
@@ -620,7 +643,8 @@ mod tests {
     fn terms_aggregation_test_order_count_merge_segment(merge_segments: bool) -> crate::Result<()> {
         let segment_and_terms = vec![
             vec![(5.0, "terma".to_string())],
-            vec![(4.0, "termb".to_string())],
+            vec![(2.0, "termb".to_string())],
+            vec![(2.0, "terma".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
             vec![(1.0, "termc".to_string())],
@@ -661,7 +685,7 @@ mod tests {
                     }),
                     ..Default::default()
                 }),
-                sub_aggregation: sub_agg,
+                sub_aggregation: sub_agg.clone(),
             }),
         )]
         .into_iter()
@@ -670,18 +694,114 @@ mod tests {
         let res = exec_request(agg_req, &index)?;
         assert_eq!(res["my_texts"]["buckets"][0]["key"], "termb");
         assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
-        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 6.0);
+        assert_eq!(res["my_texts"]["buckets"][0]["avg_score"]["value"], 5.0);
 
         assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc");
         assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 3);
         assert_eq!(res["my_texts"]["buckets"][1]["avg_score"]["value"], 1.0);
 
         assert_eq!(res["my_texts"]["buckets"][2]["key"], "terma");
-        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 5);
-        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 5.0);
+        assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 6);
+        assert_eq!(res["my_texts"]["buckets"][2]["avg_score"]["value"], 4.5);
 
         assert_eq!(res["my_texts"]["sum_other_doc_count"], 0);
 
+        // Agg on non string
+        //
+        let agg_req: Aggregations = vec![
+            (
+                "my_scores1".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores2".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_f64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg.clone(),
+                }),
+            ),
+            (
+                "my_scores3".to_string(),
+                Aggregation::Bucket(BucketAggregation {
+                    bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                        field: "score_i64".to_string(),
+                        order: Some(CustomOrder {
+                            order: Order::Asc,
+                            target: OrderTarget::Count,
+                        }),
+                        ..Default::default()
+                    }),
+                    sub_aggregation: sub_agg,
+                }),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let res = exec_request(agg_req, &index)?;
+        assert_eq!(res["my_scores1"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores1"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores1"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores1"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores1"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores1"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores1"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores1"]["buckets"][3]["key"], 5.0);
+        assert_eq!(res["my_scores1"]["buckets"][3]["doc_count"], 5);
+        assert_eq!(res["my_scores1"]["buckets"][3]["avg_score"]["value"], 5.0);
+
+        assert_eq!(res["my_scores1"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores2"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores2"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores2"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores2"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores2"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores2"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores2"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores2"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores2"]["sum_other_doc_count"], 0);
+
+        assert_eq!(res["my_scores3"]["buckets"][0]["key"], 8.0);
+        assert_eq!(res["my_scores3"]["buckets"][0]["doc_count"], 1);
+        assert_eq!(res["my_scores3"]["buckets"][0]["avg_score"]["value"], 8.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][1]["key"], 2.0);
+        assert_eq!(res["my_scores3"]["buckets"][1]["doc_count"], 2);
+        assert_eq!(res["my_scores3"]["buckets"][1]["avg_score"]["value"], 2.0);
+
+        assert_eq!(res["my_scores3"]["buckets"][2]["key"], 1.0);
+        assert_eq!(res["my_scores3"]["buckets"][2]["doc_count"], 3);
+        assert_eq!(res["my_scores3"]["buckets"][2]["avg_score"]["value"], 1.0);
+
+        assert_eq!(res["my_scores3"]["sum_other_doc_count"], 0);
+
         Ok(())
     }
@@ -373,7 +373,7 @@ impl IntermediateBucketResult {
                 IntermediateBucketResult::Terms(term_res_left),
                 IntermediateBucketResult::Terms(term_res_right),
             ) => {
-                merge_maps(&mut term_res_left.entries, term_res_right.entries);
+                merge_key_maps(&mut term_res_left.entries, term_res_right.entries);
                 term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
                 term_res_left.doc_count_error_upper_bound +=
                     term_res_right.doc_count_error_upper_bound;
@@ -383,7 +383,7 @@ impl IntermediateBucketResult {
                 IntermediateBucketResult::Range(range_res_left),
                 IntermediateBucketResult::Range(range_res_right),
             ) => {
-                merge_maps(&mut range_res_left.buckets, range_res_right.buckets);
+                merge_serialized_key_maps(&mut range_res_left.buckets, range_res_right.buckets);
             }
             (
                 IntermediateBucketResult::Histogram {
@@ -435,7 +435,7 @@ pub struct IntermediateRangeBucketResult {
 #[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// Term aggregation including error counts
 pub struct IntermediateTermBucketResult {
-    pub(crate) entries: FxHashMap<String, IntermediateTermBucketEntry>,
+    pub(crate) entries: FxHashMap<Key, IntermediateTermBucketEntry>,
     pub(crate) sum_other_doc_count: u64,
     pub(crate) doc_count_error_upper_bound: u64,
 }
@@ -454,7 +454,7 @@ impl IntermediateTermBucketResult {
             .map(|(key, entry)| {
                 Ok(BucketEntry {
                     key_as_string: None,
-                    key: Key::Str(key),
+                    key,
                     doc_count: entry.doc_count,
                     sub_aggregation: entry
                         .sub_aggregation
@@ -532,7 +532,7 @@ trait MergeFruits {
     fn merge_fruits(&mut self, other: Self);
 }
 
-fn merge_maps<V: MergeFruits + Clone>(
+fn merge_serialized_key_maps<V: MergeFruits + Clone>(
    entries_left: &mut FxHashMap<SerializedKey, V>,
    mut entries_right: FxHashMap<SerializedKey, V>,
 ) {
@@ -547,6 +547,21 @@ fn merge_maps<V: MergeFruits + Clone>(
     }
 }
 
+fn merge_key_maps<V: MergeFruits + Clone>(
+    entries_left: &mut FxHashMap<Key, V>,
+    mut entries_right: FxHashMap<Key, V>,
+) {
+    for (name, entry_left) in entries_left.iter_mut() {
+        if let Some(entry_right) = entries_right.remove(name) {
+            entry_left.merge_fruits(entry_right);
+        }
+    }
+
+    for (key, res) in entries_right.into_iter() {
+        entries_left.entry(key).or_insert(res);
+    }
+}
+
 /// This is the histogram entry for a bucket, which contains a key, count, and optionally
 /// sub_aggregations.
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
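To make the merge semantics concrete, here is a hedged toy example of the same two-phase merge that `merge_key_maps` performs when combining per-segment term buckets (the `Count` type is invented for illustration; the real values implement `MergeFruits` with richer state):

    use rustc_hash::FxHashMap;

    #[derive(Clone, Debug, PartialEq)]
    struct Count(u64);

    impl Count {
        // Same contract as MergeFruits::merge_fruits: fold the right entry into the left one.
        fn merge_fruits(&mut self, other: Count) {
            self.0 += other.0;
        }
    }

    fn main() {
        let mut left: FxHashMap<String, Count> = FxHashMap::default();
        left.insert("terma".to_string(), Count(2));

        let mut right: FxHashMap<String, Count> = FxHashMap::default();
        right.insert("terma".to_string(), Count(3));
        right.insert("termb".to_string(), Count(1));

        // Phase 1: merge entries whose key exists on both sides.
        for (key, entry_left) in left.iter_mut() {
            if let Some(entry_right) = right.remove(key) {
                entry_left.merge_fruits(entry_right);
            }
        }
        // Phase 2: adopt entries that only the right side produced.
        for (key, entry) in right {
            left.entry(key).or_insert(entry);
        }

        assert_eq!(left["terma"], Count(5));
        assert_eq!(left["termb"], Count(1));
    }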
@@ -10,7 +10,7 @@
 //! There are two categories: [Metrics](metric) and [Buckets](bucket).
 //!
 //! ## Prerequisite
-//! Currently aggregations work only on [fast fields](`crate::fastfield`). Single value fast fields
+//! Currently aggregations work only on [fast fields](`crate::fastfield`). Fast fields
 //! of type `u64`, `f64`, `i64`, `date` and fast fields on text fields.
 //!
 //! ## Usage
@@ -262,7 +262,7 @@ impl<T: Clone> VecWithNames<T> {
 /// The serialized key is used in a `HashMap`.
 pub type SerializedKey = String;
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, PartialOrd)]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd)]
 /// The key to identify a bucket.
 #[serde(untagged)]
 pub enum Key {
@@ -271,6 +271,26 @@ pub enum Key {
     /// `f64` key
     F64(f64),
 }
+impl Eq for Key {}
+impl std::hash::Hash for Key {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        core::mem::discriminant(self).hash(state);
+        match self {
+            Key::Str(text) => text.hash(state),
+            Key::F64(val) => val.to_bits().hash(state),
+        }
+    }
+}
+
+impl PartialEq for Key {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Str(l), Self::Str(r)) => l == r,
+            (Self::F64(l), Self::F64(r)) => l == r,
+            _ => false,
+        }
+    }
+}
+
 impl Display for Key {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
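`f64` implements neither `Eq` nor `Hash` (NaN breaks both), so once buckets can be keyed by `Key::F64` the `PartialEq` derive is dropped and the impls above are written by hand, hashing the bit pattern via `to_bits`. A self-contained sketch of the same pattern, with an illustrative `BucketKey` type rather than tantivy's `Key`:

    use std::collections::HashMap;

    // Illustrative stand-in for tantivy's Key enum.
    #[derive(Clone, Debug)]
    enum BucketKey {
        Str(String),
        F64(f64),
    }

    impl PartialEq for BucketKey {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (Self::Str(l), Self::Str(r)) => l == r,
                // Compare bit patterns so Eq and Hash stay consistent even for NaN
                // (the tantivy impl above compares the floats directly).
                (Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
                _ => false,
            }
        }
    }
    impl Eq for BucketKey {}

    impl std::hash::Hash for BucketKey {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            // Hash the variant tag first so string and float keys cannot collide by accident,
            // then hash the payload; f64 goes through to_bits because it has no Hash impl.
            core::mem::discriminant(self).hash(state);
            match self {
                BucketKey::Str(text) => text.hash(state),
                BucketKey::F64(val) => val.to_bits().hash(state),
            }
        }
    }

    fn main() {
        let mut doc_counts: HashMap<BucketKey, u64> = HashMap::new();
        *doc_counts.entry(BucketKey::F64(2.0)).or_insert(0) += 1;
        *doc_counts.entry(BucketKey::F64(2.0)).or_insert(0) += 1;
        assert_eq!(doc_counts[&BucketKey::F64(2.0)], 2);
    }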
@@ -150,6 +150,7 @@ pub(crate) fn build_bucket_segment_agg_collector(
             SegmentTermCollector::from_req_and_validate(
                 terms_req,
                 &req.sub_aggregation,
+                req.field_type,
                 accessor_idx,
             )?,
         )),