mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2025-12-23 02:29:57 +00:00
support bool type in term aggregation (#2318)
* support bool type in term aggregation * add Bool to Intermediate Key
This commit is contained in:
@@ -169,8 +169,8 @@ impl AggregationWithAccessor {
|
||||
ColumnType::F64,
|
||||
ColumnType::Str,
|
||||
ColumnType::DateTime,
|
||||
ColumnType::Bool,
|
||||
// ColumnType::Bytes Unsupported
|
||||
// ColumnType::Bool Unsupported
|
||||
// ColumnType::IpAddr Unsupported
|
||||
];
|
||||
|
||||
|
||||
@@ -587,6 +587,9 @@ fn test_aggregation_on_json_object() {
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"color": "red"})))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"color": "red"})))
|
||||
.unwrap();
|
||||
@@ -614,8 +617,8 @@ fn test_aggregation_on_json_object() {
|
||||
&serde_json::json!({
|
||||
"jsonagg": {
|
||||
"buckets": [
|
||||
{"doc_count": 2, "key": "red"},
|
||||
{"doc_count": 1, "key": "blue"},
|
||||
{"doc_count": 1, "key": "red"}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
"sum_other_doc_count": 0
|
||||
@@ -637,6 +640,9 @@ fn test_aggregation_on_nested_json_object() {
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} })))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} })))
|
||||
.unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
@@ -664,7 +670,7 @@ fn test_aggregation_on_nested_json_object() {
|
||||
&serde_json::json!({
|
||||
"jsonagg1": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "blue"},
|
||||
{"doc_count": 2, "key": "blue"},
|
||||
{"doc_count": 1, "key": "red"}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
@@ -672,7 +678,7 @@ fn test_aggregation_on_nested_json_object() {
|
||||
},
|
||||
"jsonagg2": {
|
||||
"buckets": [
|
||||
{"doc_count": 1, "key": "blue"},
|
||||
{"doc_count": 2, "key": "blue"},
|
||||
{"doc_count": 1, "key": "red"}
|
||||
],
|
||||
"doc_count_error_upper_bound": 0,
|
||||
@@ -814,6 +820,12 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
.unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
// => Segment with all values text
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"mixed_type": "blue"})))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"mixed_type": "blue"})))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"mixed_type": "blue"})))
|
||||
.unwrap();
|
||||
@@ -825,6 +837,9 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
// => Segment with mixed values
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"mixed_type": "red"})))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(json => json!({"mixed_type": "red"})))
|
||||
.unwrap();
|
||||
@@ -870,6 +885,8 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
|
||||
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
|
||||
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
|
||||
// pretty print as json
|
||||
use pretty_assertions::assert_eq;
|
||||
assert_eq!(
|
||||
&aggregation_res_json,
|
||||
&serde_json::json!({
|
||||
@@ -885,9 +902,9 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
"buckets": [
|
||||
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
|
||||
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
|
||||
// TODO bool is also not yet handled in aggregation
|
||||
{ "doc_count": 1, "key": "blue", "min_price": { "value": null } },
|
||||
{ "doc_count": 1, "key": "red", "min_price": { "value": null } },
|
||||
{ "doc_count": 2, "key": "red", "min_price": { "value": null } },
|
||||
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
|
||||
{ "doc_count": 3, "key": "blue", "min_price": { "value": null } },
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
}
|
||||
|
||||
@@ -352,8 +352,10 @@ pub mod tests {
|
||||
let docs = vec![
|
||||
vec![r#"{ "date": "2015-01-01T12:10:30Z", "text": "aaa" }"#],
|
||||
vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#],
|
||||
vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#],
|
||||
vec![r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb" }"#],
|
||||
vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#],
|
||||
vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#],
|
||||
];
|
||||
let index = get_test_index_from_docs(merge_segments, &docs).unwrap();
|
||||
|
||||
@@ -382,7 +384,7 @@ pub mod tests {
|
||||
{
|
||||
"key_as_string" : "2015-01-01T00:00:00Z",
|
||||
"key" : 1420070400000.0,
|
||||
"doc_count" : 4
|
||||
"doc_count" : 6
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -420,15 +422,15 @@ pub mod tests {
|
||||
{
|
||||
"key_as_string" : "2015-01-01T00:00:00Z",
|
||||
"key" : 1420070400000.0,
|
||||
"doc_count" : 4,
|
||||
"doc_count" : 6,
|
||||
"texts": {
|
||||
"buckets": [
|
||||
{
|
||||
"doc_count": 2,
|
||||
"doc_count": 3,
|
||||
"key": "bbb"
|
||||
},
|
||||
{
|
||||
"doc_count": 1,
|
||||
"doc_count": 2,
|
||||
"key": "ccc"
|
||||
},
|
||||
{
|
||||
@@ -467,7 +469,7 @@ pub mod tests {
|
||||
"sales_over_time": {
|
||||
"buckets": [
|
||||
{
|
||||
"doc_count": 2,
|
||||
"doc_count": 3,
|
||||
"key": 1420070400000.0,
|
||||
"key_as_string": "2015-01-01T00:00:00Z"
|
||||
},
|
||||
@@ -492,7 +494,7 @@ pub mod tests {
|
||||
"key_as_string": "2015-01-05T00:00:00Z"
|
||||
},
|
||||
{
|
||||
"doc_count": 1,
|
||||
"doc_count": 2,
|
||||
"key": 1420502400000.0,
|
||||
"key_as_string": "2015-01-06T00:00:00Z"
|
||||
}
|
||||
@@ -533,7 +535,7 @@ pub mod tests {
|
||||
"key_as_string": "2014-12-31T00:00:00Z"
|
||||
},
|
||||
{
|
||||
"doc_count": 2,
|
||||
"doc_count": 3,
|
||||
"key": 1420070400000.0,
|
||||
"key_as_string": "2015-01-01T00:00:00Z"
|
||||
},
|
||||
@@ -558,7 +560,7 @@ pub mod tests {
|
||||
"key_as_string": "2015-01-05T00:00:00Z"
|
||||
},
|
||||
{
|
||||
"doc_count": 1,
|
||||
"doc_count": 2,
|
||||
"key": 1420502400000.0,
|
||||
"key_as_string": "2015-01-06T00:00:00Z"
|
||||
},
|
||||
|
||||
@@ -256,7 +256,7 @@ pub struct SegmentTermCollector {
|
||||
term_buckets: TermBuckets,
|
||||
req: TermsAggregationInternal,
|
||||
blueprint: Option<Box<dyn SegmentAggregationCollector>>,
|
||||
field_type: ColumnType,
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
}
|
||||
|
||||
@@ -355,7 +355,7 @@ impl SegmentTermCollector {
|
||||
field_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
) -> crate::Result<Self> {
|
||||
if field_type == ColumnType::Bytes || field_type == ColumnType::Bool {
|
||||
if field_type == ColumnType::Bytes {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"terms aggregation is not supported for column type {:?}",
|
||||
field_type
|
||||
@@ -389,7 +389,7 @@ impl SegmentTermCollector {
|
||||
req: TermsAggregationInternal::from_req(req),
|
||||
term_buckets,
|
||||
blueprint,
|
||||
field_type,
|
||||
column_type: field_type,
|
||||
accessor_idx,
|
||||
})
|
||||
}
|
||||
@@ -466,7 +466,7 @@ impl SegmentTermCollector {
|
||||
Ok(intermediate_entry)
|
||||
};
|
||||
|
||||
if self.field_type == ColumnType::Str {
|
||||
if self.column_type == ColumnType::Str {
|
||||
let term_dict = agg_with_accessor
|
||||
.str_dict_column
|
||||
.as_ref()
|
||||
@@ -531,28 +531,34 @@ impl SegmentTermCollector {
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if self.field_type == ColumnType::DateTime {
|
||||
} else if self.column_type == ColumnType::DateTime {
|
||||
for (val, doc_count) in entries {
|
||||
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
|
||||
let val = i64::from_u64(val);
|
||||
let date = format_date(val)?;
|
||||
dict.insert(IntermediateKey::Str(date), intermediate_entry);
|
||||
}
|
||||
} else if self.column_type == ColumnType::Bool {
|
||||
for (val, doc_count) in entries {
|
||||
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
|
||||
let val = bool::from_u64(val);
|
||||
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
|
||||
}
|
||||
} else {
|
||||
for (val, doc_count) in entries {
|
||||
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
|
||||
let val = f64_from_fastfield_u64(val, &self.field_type);
|
||||
let val = f64_from_fastfield_u64(val, &self.column_type);
|
||||
dict.insert(IntermediateKey::F64(val), intermediate_entry);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(IntermediateBucketResult::Terms(
|
||||
IntermediateTermBucketResult {
|
||||
Ok(IntermediateBucketResult::Terms {
|
||||
buckets: IntermediateTermBucketResult {
|
||||
entries: dict,
|
||||
sum_other_doc_count,
|
||||
doc_count_error_upper_bound: term_doc_count_before_cutoff,
|
||||
},
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1365,7 +1371,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn terms_aggregation_different_tokenizer_on_ff_test() -> crate::Result<()> {
|
||||
let terms = vec!["Hello Hello", "Hallo Hallo"];
|
||||
let terms = vec!["Hello Hello", "Hallo Hallo", "Hallo Hallo"];
|
||||
|
||||
let index = get_test_index_from_terms(true, &[terms])?;
|
||||
|
||||
@@ -1383,7 +1389,7 @@ mod tests {
|
||||
println!("{}", serde_json::to_string_pretty(&res).unwrap());
|
||||
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["key"], "Hallo Hallo");
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1);
|
||||
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2);
|
||||
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["key"], "Hello Hello");
|
||||
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1);
|
||||
@@ -1894,4 +1900,40 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn terms_aggregation_bool() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_bool_field("bool_field", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
{
|
||||
let mut writer = index.writer_with_num_threads(1, 15_000_000)?;
|
||||
writer.add_document(doc!(field=>true))?;
|
||||
writer.add_document(doc!(field=>false))?;
|
||||
writer.add_document(doc!(field=>true))?;
|
||||
writer.commit()?;
|
||||
}
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"my_bool": {
|
||||
"terms": {
|
||||
"field": "bool_field"
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request_with_query(agg_req, &index, None)?;
|
||||
|
||||
assert_eq!(res["my_bool"]["buckets"][0]["key"], 1.0);
|
||||
assert_eq!(res["my_bool"]["buckets"][0]["key_as_string"], "true");
|
||||
assert_eq!(res["my_bool"]["buckets"][0]["doc_count"], 2);
|
||||
assert_eq!(res["my_bool"]["buckets"][1]["key"], 0.0);
|
||||
assert_eq!(res["my_bool"]["buckets"][1]["key_as_string"], "false");
|
||||
assert_eq!(res["my_bool"]["buckets"][1]["doc_count"], 1);
|
||||
assert_eq!(res["my_bool"]["buckets"][2]["key"], serde_json::Value::Null);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,11 +73,13 @@ impl SegmentAggregationCollector for TermMissingAgg {
|
||||
|
||||
entries.insert(missing.into(), missing_entry);
|
||||
|
||||
let bucket = IntermediateBucketResult::Terms(IntermediateTermBucketResult {
|
||||
entries,
|
||||
sum_other_doc_count: 0,
|
||||
doc_count_error_upper_bound: 0,
|
||||
});
|
||||
let bucket = IntermediateBucketResult::Terms {
|
||||
buckets: IntermediateTermBucketResult {
|
||||
entries,
|
||||
sum_other_doc_count: 0,
|
||||
doc_count_error_upper_bound: 0,
|
||||
},
|
||||
};
|
||||
|
||||
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
|
||||
|
||||
|
||||
@@ -41,6 +41,8 @@ pub struct IntermediateAggregationResults {
|
||||
/// This might seem redundant with `Key`, but the point is to have a different
|
||||
/// Serialize implementation.
|
||||
pub enum IntermediateKey {
|
||||
/// Bool key
|
||||
Bool(bool),
|
||||
/// String key
|
||||
Str(String),
|
||||
/// `f64` key
|
||||
@@ -59,6 +61,7 @@ impl From<IntermediateKey> for Key {
|
||||
match value {
|
||||
IntermediateKey::Str(s) => Self::Str(s),
|
||||
IntermediateKey::F64(f) => Self::F64(f),
|
||||
IntermediateKey::Bool(f) => Self::F64(f as u64 as f64),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,6 +74,7 @@ impl std::hash::Hash for IntermediateKey {
|
||||
match self {
|
||||
IntermediateKey::Str(text) => text.hash(state),
|
||||
IntermediateKey::F64(val) => val.to_bits().hash(state),
|
||||
IntermediateKey::Bool(val) => val.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,9 +170,9 @@ impl IntermediateAggregationResults {
|
||||
pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult {
|
||||
use AggregationVariants::*;
|
||||
match req.agg {
|
||||
Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms(
|
||||
Default::default(),
|
||||
)),
|
||||
Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms {
|
||||
buckets: Default::default(),
|
||||
}),
|
||||
Range(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range(
|
||||
Default::default(),
|
||||
)),
|
||||
@@ -363,11 +367,14 @@ pub enum IntermediateBucketResult {
|
||||
Histogram {
|
||||
/// The column_type of the underlying `Column` is DateTime
|
||||
is_date_agg: bool,
|
||||
/// The buckets
|
||||
/// The histogram buckets
|
||||
buckets: Vec<IntermediateHistogramBucketEntry>,
|
||||
},
|
||||
/// Term aggregation
|
||||
Terms(IntermediateTermBucketResult),
|
||||
Terms {
|
||||
/// The term buckets
|
||||
buckets: IntermediateTermBucketResult,
|
||||
},
|
||||
}
|
||||
|
||||
impl IntermediateBucketResult {
|
||||
@@ -444,7 +451,7 @@ impl IntermediateBucketResult {
|
||||
};
|
||||
Ok(BucketResult::Histogram { buckets })
|
||||
}
|
||||
IntermediateBucketResult::Terms(terms) => terms.into_final_result(
|
||||
IntermediateBucketResult::Terms { buckets: terms } => terms.into_final_result(
|
||||
req.agg
|
||||
.as_term()
|
||||
.expect("unexpected aggregation, expected term aggregation"),
|
||||
@@ -457,8 +464,12 @@ impl IntermediateBucketResult {
|
||||
fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> {
|
||||
match (self, other) {
|
||||
(
|
||||
IntermediateBucketResult::Terms(term_res_left),
|
||||
IntermediateBucketResult::Terms(term_res_right),
|
||||
IntermediateBucketResult::Terms {
|
||||
buckets: term_res_left,
|
||||
},
|
||||
IntermediateBucketResult::Terms {
|
||||
buckets: term_res_right,
|
||||
},
|
||||
) => {
|
||||
merge_maps(&mut term_res_left.entries, term_res_right.entries)?;
|
||||
term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
|
||||
@@ -542,8 +553,15 @@ impl IntermediateTermBucketResult {
|
||||
.into_iter()
|
||||
.filter(|bucket| bucket.1.doc_count as u64 >= req.min_doc_count)
|
||||
.map(|(key, entry)| {
|
||||
let key_as_string = match key {
|
||||
IntermediateKey::Bool(key) => {
|
||||
let val = if key { "true" } else { "false" };
|
||||
Some(val.to_string())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
Ok(BucketEntry {
|
||||
key_as_string: None,
|
||||
key_as_string,
|
||||
key: key.into(),
|
||||
doc_count: entry.doc_count as u64,
|
||||
sub_aggregation: entry
|
||||
|
||||
@@ -281,6 +281,7 @@ pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
|
||||
ColumnType::U64 => val as f64,
|
||||
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
|
||||
ColumnType::F64 => f64::from_u64(val),
|
||||
ColumnType::Bool => val as f64,
|
||||
_ => {
|
||||
panic!("unexpected type {field_type:?}. This should not happen")
|
||||
}
|
||||
@@ -301,6 +302,7 @@ pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &ColumnType) -> Option<
|
||||
ColumnType::U64 => Some(val as u64),
|
||||
ColumnType::I64 | ColumnType::DateTime => Some((val as i64).to_u64()),
|
||||
ColumnType::F64 => Some(val.to_u64()),
|
||||
ColumnType::Bool => Some(val as u64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user