feat: cardinality aggregation (#2337)

* WiP: cardinality aggregation

* Collect unique entries first, then insert into HyperLogLog

* Handle `missing`

* Hybrid approach

* Review changes

- insert `missing` value at most once
- `term_id` -> `term_ord`
- iterate directly over entries without collecting first

* Use salted hasher to include column type

* fix: formatting

* More review fixes

* Add cardinality to test_aggregation_flushing

* Formatting
This commit is contained in:
Raphael Coeffic
2024-07-01 01:49:42 +02:00
committed by GitHub
parent e453848134
commit d9db5302d9
9 changed files with 472 additions and 5 deletions

View File

@@ -64,6 +64,7 @@ tantivy-bitpacker = { version = "0.6", path = "./bitpacker" }
common = { version = "0.7", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
fnv = "1.0.7"

View File

@@ -34,8 +34,9 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation,
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
TopHitsAggregation,
};
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
@@ -160,6 +161,9 @@ pub enum AggregationVariants {
/// Finds the top k values matching some order
#[serde(rename = "top_hits")]
TopHits(TopHitsAggregation),
/// Computes an estimate of the number of unique values
#[serde(rename = "cardinality")]
Cardinality(CardinalityAggregationReq),
}
impl AggregationVariants {
@@ -179,6 +183,7 @@ impl AggregationVariants {
AggregationVariants::Sum(sum) => vec![sum.field_name()],
AggregationVariants::Percentiles(per) => vec![per.field_name()],
AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
AggregationVariants::Cardinality(per) => vec![per.field_name()],
}
}

View File

@@ -11,8 +11,8 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
StatsAggregation, SumAggregation,
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
};
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
@@ -162,6 +162,11 @@ impl AggregationWithAccessor {
field: ref field_name,
ref missing,
..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => {
let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [

View File

@@ -98,6 +98,8 @@ pub enum MetricResult {
Percentiles(PercentilesMetricResult),
/// Top hits metric result
TopHits(TopHitsMetricResult),
/// Cardinality metric result
Cardinality(SingleMetricResult),
}
impl MetricResult {
@@ -116,6 +118,7 @@ impl MetricResult {
MetricResult::TopHits(_) => Err(TantivyError::AggregationError(
AggregationError::InvalidRequest("top_hits can't be used to order".to_string()),
)),
MetricResult::Cardinality(card) => Ok(card.value),
}
}
}

View File

@@ -110,6 +110,16 @@ fn test_aggregation_flushing(
}
}
}
},
"cardinality_string_id":{
"cardinality": {
"field": "string_id"
}
},
"cardinality_score":{
"cardinality": {
"field": "score"
}
}
});
@@ -212,6 +222,9 @@ fn test_aggregation_flushing(
)
);
assert_eq!(res["cardinality_string_id"]["value"], 2.0);
assert_eq!(res["cardinality_score"]["value"], 80.0);
Ok(())
}

View File

@@ -26,6 +26,7 @@ use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -227,6 +228,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
TopHits(ref req) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)),
),
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
}
}
@@ -291,6 +295,8 @@ pub enum IntermediateMetricResult {
Sum(IntermediateSum),
/// Intermediate top_hits result
TopHits(TopHitsTopNComputer),
/// Intermediate cardinality result
Cardinality(CardinalityCollector),
}
impl IntermediateMetricResult {
@@ -324,6 +330,9 @@ impl IntermediateMetricResult {
IntermediateMetricResult::TopHits(top_hits) => {
MetricResult::TopHits(top_hits.into_final_result())
}
IntermediateMetricResult::Cardinality(cardinality) => {
MetricResult::Cardinality(cardinality.finalize().into())
}
}
}
@@ -372,6 +381,12 @@ impl IntermediateMetricResult {
(IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => {
left.merge_fruits(right)?;
}
(
IntermediateMetricResult::Cardinality(left),
IntermediateMetricResult::Cardinality(right),
) => {
left.merge_fruits(right)?;
}
_ => {
panic!("incompatible fruit types in tree or missing merge_fruits handler");
}

View File

@@ -0,0 +1,417 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{BytesColumn, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// HyperLogLog++ alogrithm.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CardinalityAggregationReq {
/// The field name to compute the percentiles on.
pub field: String,
/// The missing parameter defines how documents that are missing a value should be treated.
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(skip_serializing_if = "Option::is_none", default)]
pub missing: Option<Key>,
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
Self {
field: field_name,
missing: None,
}
}
/// Returns the field name the aggregation is computed on.
pub fn field_name(&self) -> &str {
&self.field
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
accessor_idx: usize,
missing: Option<Key>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateMetricResult> {
if self.column_type == ColumnType::Str {
let mut buffer = String::new();
let term_dict = agg_with_accessor
.str_dict_column
.as_ref()
.cloned()
.unwrap_or_else(|| {
StrColumn::wrap(BytesColumn::empty(agg_with_accessor.accessor.num_docs()))
});
let mut has_missing = false;
for term_ord in self.entries.into_iter() {
if term_ord == u64::MAX {
has_missing = true;
} else {
if !term_dict.ord_to_str(term_ord, &mut buffer)? {
return Err(TantivyError::InternalError(format!(
"Couldn't find term_ord {term_ord} in dict"
)));
}
self.cardinality.sketch.insert_any(&buffer);
}
}
if has_missing {
let missing_key = self
.missing
.as_ref()
.expect("Found placeholder term_ord but `missing` is None");
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.sketch.insert_any(&val);
}
}
}
}
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
)?;
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = bucket_agg_accessor
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
pub struct CardinalityCollector {
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
}
}
impl PartialEq for CardinalityCollector {
fn eq(&self, _other: &Self) -> bool {
false
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
}
fn new(salt: u8) -> Self {
Self {
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
}
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::net::IpAddr;
use std::str::FromStr;
use columnar::MonotonicallyMappableToU64;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
use crate::schema::{IntoIpv6Addr, Schema, FAST};
use crate::Index;
#[test]
fn cardinality_aggregation_test_empty_index() -> crate::Result<()> {
let values = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 0.0);
Ok(())
}
#[test]
fn cardinality_aggregation_test_single_segment() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(true)
}
#[test]
fn cardinality_aggregation_test() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(false)
}
fn cardinality_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec!["terma"],
vec!["termb"],
vec!["termc"],
vec!["terma"],
vec!["terma"],
vec!["terma"],
vec!["termb"],
vec!["terma"],
];
let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 3.0);
Ok(())
}
#[test]
fn cardinality_aggregation_u64() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(id_field => 1u64))?;
writer.add_document(doc!(id_field => 2u64))?;
writer.add_document(doc!(id_field => 3u64))?;
writer.add_document(doc!())?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "id",
"missing": 0u64
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
#[test]
fn cardinality_aggregation_ip_addr() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_ip_addr_field("ip_field", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
// IpV6 loopback
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
// IpV4
writer.add_document(
doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()),
)?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "ip_field"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 2.0);
Ok(())
}
#[test]
fn cardinality_aggregation_json() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_json_field("json", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(field => json!({"value": false})))?;
writer.add_document(doc!(field => json!({"value": true})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(0u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(1u64)})))?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "json.value"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
}

View File

@@ -17,6 +17,7 @@
//! - [Percentiles](PercentilesAggregationReq)
mod average;
mod cardinality;
mod count;
mod extended_stats;
mod max;
@@ -29,6 +30,7 @@ mod top_hits;
use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
pub use count::*;
pub use extended_stats::*;
pub use max::*;

View File

@@ -16,7 +16,10 @@ use super::metric::{
SumAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector};
use crate::aggregation::metric::{
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
@@ -169,6 +172,9 @@ pub(crate) fn build_single_agg_segment_collector(
accessor_idx,
req.segment_ordinal,
))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
}
}