Compare commits

..

11 Commits

Author SHA1 Message Date
Paul Masurel
b2573a3b16 low cardinality optimisation 2025-11-19 18:41:10 +01:00
Moe
70e591e230 feat: added filter aggregation (#2711)
* Initial impl

* Added `Filter` impl in `build_single_agg_segment_collector_with_reader` + Added tests

* Added `Filter(FilterBucketResult)` + Made tests work.

* Fixed type issues.

* Fixed a test.

* 8a7a73a: Pass `segment_reader`

* Added more tests.

* Improved parsing + tests

* refactoring

* Added more tests.

* refactoring: moved parsing code under QueryParser

* Use Tantivy syntax instead of ES

* Added a sanity check test.

* Simplified impl + tests

* Added back tests in a more maintable way

* nitz.

* nitz

* implemented very simple fast-path

* improved a comment

* implemented fast field support

* Used `BoundsRange`

* Improved fast field impl + tests

* Simplified execution.

* Fixed exports + nitz

* Improved the tests to check to the expected result.

* Improved test by checking the whole result JSON

* Removed brittle perf checks.

* Added efficiency verification tests.

* Added one more efficiency check test.

* Improved the efficiency tests.

* Removed unnecessary parsing code + added direct Query obj

* Fixed tests.

* Improved tests

* Fixed code structure

* Fixed lint issues

* nitz.

* nitz

* nitz.

* nitz.

* nitz.

* Added an example

* Fixed PR comments.

* Applied PR comments + nitz

* nitz.

* Improved the code.

* Fixed a perf issue.

* Added batch processing.

* Made the example more interesting

* Fixed bucket count

* Renamed Direct to CustomQuery

* Fixed lint issues.

* No need for scorer to be an `Option`

* nitz

* Used BitSet

* Added an optimization for AllQuery

* Fixed merge issues.

* Fixed lint issues.

* Added benchmark for FILTER

* Removed the Option wrapper.

* nitz.

* Applied PR comments.

* Fixed the AllQuery optimization

* Applied PR comments.

* feat: used `erased_serde` to allow filter query to be serialized

* further improved a comment

* Added back tests.

* removed an unused method

* removed an unused method

* Added documentation

* nitz.

* Added query builder.

* Fixed a comment.

* Applied PR comments.

* Fixed doctest issues.

* Added ser/de

* Removed bench in test

* Fixed a lint issue.
2025-11-18 20:54:31 +01:00
Arthur
5277367cb0 remove duplicated call to index_writer.commit() in example (#2732) 2025-11-12 14:52:44 +01:00
Paul Masurel
8b02bff9b8 Removing obsolete benchmark screenshot (#2730)
Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-05 09:55:13 +01:00
PSeitz
60225bdd45 cleanup (#2724)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-23 10:23:34 +02:00
PSeitz
938bfec8b7 use FxHashMap for Aggregations Request (#2722)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-21 15:59:18 +02:00
PSeitz
dabcaa5809 fix merge intermediate aggregation results (#2719)
Previously the merging relied on the order of the results, which is invalid since https://github.com/quickwit-oss/tantivy/pull/2035.
This bug is only hit in specific scenarios, when the aggregation collectors are built in a different order on different segments.

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-17 12:41:31 +02:00
PSeitz
d410a3b0c0 Add Filtering for Term Aggregations (#2717)
* Add Filtering for Term Aggregations

Closes #2702

* add AggregationsSegmentCtx memory consumption

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-15 17:39:53 +02:00
Remi
fc93391d0e Minor clarifications on the AggregationsWithAccessor refacto (#2716) 2025-10-14 19:59:33 +02:00
PSeitz
f8e79271ab Replace AggregationsWithAccessor (#2715)
* add nested histogram-termagg benchmark

* Replace AggregationsWithAccessor with AggData

With AggregationsWithAccessor pre-computation and caching was done on the collector level.
If you have 10000 sub collectors (e.g. a term aggregation with sub aggregations) this is very inefficient.
`AggData` instead moves the data from the collector to a node which reflects the cardinality of the request tree instead of the cardinality of the segment collector.
It also moves the global struct shared with all aggregations in to aggregation specific structs. So each aggregation has its own space to store cached data and aggregation specific information.

This also breaks up the dependency to the elastic search aggregation structure somewhat.

Due to lifetime issues, we move the agg request specific object out of `AggData` during the collection and move it back at the end (for now). That's some unnecessary work, which costs CPU.

This allows better caching and will also pave the way for another potential optimization, by separating the collector and its storage. Currently we allocate a new collector for each sub aggregation bucket (for nested aggregations), but ideally we would have just one collector instance.

* renames

* move request data to agg request files

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-14 09:22:11 +02:00
PSeitz
33835b6a01 Add DocSet::cost() (#2707)
* query: add DocSet cost hint and use it for intersection ordering

- Add DocSet::cost()
- Use cost() instead of size_hint() to order scorers in intersect_scorers

This isolates cost-related changes without the new seek APIs from
PR #2538

* add comments

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-10-13 16:25:49 +02:00
54 changed files with 5216 additions and 1563 deletions

View File

@@ -69,6 +69,7 @@ hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"
@@ -87,7 +88,7 @@ more-asserts = "0.3.1"
rand_distr = "0.4.3"
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
@@ -175,4 +176,3 @@ harness = false
[[bench]]
name = "and_or_queries"
harness = false

View File

@@ -23,8 +23,6 @@ performance for different types of queries/collections.
Your mileage WILL vary depending on the nature of queries and their load.
<img src="doc/assets/images/searchbenchmark.png">
Details about the benchmark can be found at this [repository](https://github.com/quickwit-oss/search-benchmark-game).
## Features

View File

@@ -71,8 +71,15 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_few);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
register!(group, filter_agg_all_query_count_agg);
register!(group, filter_agg_term_query_count_agg);
register!(group, filter_agg_all_query_with_sub_aggs);
register!(group, filter_agg_term_query_with_sub_aggs);
group.run();
}
@@ -339,6 +346,17 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_few(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } }
}
}
});
execute_agg(index, agg_req);
}
fn avg_and_range_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"rangef64": {
@@ -460,3 +478,61 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
Ok(index)
}
// Filter aggregation benchmarks
fn filter_agg_all_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_count_agg(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"count": { "value_count": { "field": "score" } }
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_all_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "*",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
}
}
}
});
execute_agg(index, agg_req);
}
fn filter_agg_term_query_with_sub_aggs(index: &Index) {
let agg_req = json!({
"filtered": {
"filter": "text:cool",
"aggs": {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
}
}
}
});
execute_agg(index, agg_req);
}

View File

@@ -29,7 +29,6 @@ impl BinarySerializable for VIntU128 {
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u128;
let mut shift = 0u64;
@@ -53,7 +52,7 @@ impl BinarySerializable for VIntU128 {
}
}
/// Wrapper over a `u64` that serializes as a variable int.
/// Wrapper over a `u64` that serializes as a variable int.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct VInt(pub u64);
@@ -197,7 +196,6 @@ impl BinarySerializable for VInt {
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;

Binary file not shown.

Before

Width:  |  Height:  |  Size: 653 KiB

View File

@@ -0,0 +1,212 @@
// # Filter Aggregation Example
//
// This example demonstrates filter aggregations - creating buckets of documents
// matching specific queries, with nested aggregations computed on each bucket.
//
// Filter aggregations are useful for computing metrics on different subsets of
// your data in a single query, like "average price overall + average price for
// electronics + count of in-stock items".
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::schema::{Schema, FAST, INDEXED, TEXT};
use tantivy::{doc, Index};
fn main() -> tantivy::Result<()> {
// Create a simple product schema
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("category", TEXT | FAST);
schema_builder.add_text_field("brand", TEXT | FAST);
schema_builder.add_u64_field("price", FAST);
schema_builder.add_f64_field("rating", FAST);
schema_builder.add_bool_field("in_stock", FAST | INDEXED);
let schema = schema_builder.build();
// Create index and add sample products
let index = Index::create_in_ram(schema.clone());
let mut writer = index.writer(50_000_000)?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "apple",
schema.get_field("price")? => 999u64,
schema.get_field("rating")? => 4.5f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "electronics",
schema.get_field("brand")? => "samsung",
schema.get_field("price")? => 799u64,
schema.get_field("rating")? => 4.2f64,
schema.get_field("in_stock")? => true
))?;
writer.add_document(doc!(
schema.get_field("category")? => "clothing",
schema.get_field("brand")? => "nike",
schema.get_field("price")? => 120u64,
schema.get_field("rating")? => 4.1f64,
schema.get_field("in_stock")? => false
))?;
writer.add_document(doc!(
schema.get_field("category")? => "books",
schema.get_field("brand")? => "penguin",
schema.get_field("price")? => 25u64,
schema.get_field("rating")? => 4.8f64,
schema.get_field("in_stock")? => true
))?;
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Example 1: Basic filter with metric aggregation
println!("=== Example 1: Electronics average price ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 2: Multiple independent filters
println!("=== Example 2: Multiple filters in one query ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": { "avg_price": { "avg": { "field": "price" } } }
},
"in_stock": {
"filter": "in_stock:true",
"aggs": { "count": { "value_count": { "field": "brand" } } }
},
"high_rated": {
"filter": "rating:[4.5 TO *]",
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"avg_price": { "value": 899.0 }
},
"in_stock": {
"doc_count": 3,
"count": { "value": 3.0 }
},
"high_rated": {
"doc_count": 2,
"count": { "value": 2.0 }
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 3: Nested filters - progressive refinement
println!("=== Example 3: Nested filters ===");
let agg_req = json!({
"in_stock": {
"filter": "in_stock:true",
"aggs": {
"electronics": {
"filter": "category:electronics",
"aggs": {
"expensive": {
"filter": "price:[800 TO *]",
"aggs": {
"avg_rating": { "avg": { "field": "rating" } }
}
}
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"in_stock": {
"doc_count": 3, // apple, samsung, penguin
"electronics": {
"doc_count": 2, // apple, samsung
"expensive": {
"doc_count": 1, // only apple (999)
"avg_rating": { "value": 4.5 }
}
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}\n", serde_json::to_string_pretty(&result)?);
// Example 4: Filter with sub-aggregation (terms)
println!("=== Example 4: Filter with terms sub-aggregation ===");
let agg_req = json!({
"electronics": {
"filter": "category:electronics",
"aggs": {
"by_brand": {
"terms": { "field": "brand" },
"aggs": {
"avg_price": { "avg": { "field": "price" } }
}
}
}
}
});
let agg: Aggregations = serde_json::from_value(agg_req)?;
let collector = AggregationCollector::from_aggs(agg, Default::default());
let result = searcher.search(&AllQuery, &collector)?;
let expected = json!({
"electronics": {
"doc_count": 2,
"by_brand": {
"buckets": [
{
"key": "samsung",
"doc_count": 1,
"avg_price": { "value": 799.0 }
},
{
"key": "apple",
"doc_count": 1,
"avg_price": { "value": 999.0 }
}
],
"sum_other_doc_count": 0,
"doc_count_error_upper_bound": 0
}
}
});
assert_eq!(serde_json::to_value(&result)?, expected);
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}

View File

@@ -85,7 +85,6 @@ fn main() -> tantivy::Result<()> {
index_writer.add_document(doc!(
title => "The Diary of a Young Girl",
))?;
index_writer.commit()?;
// ### Committing
//

View File

@@ -20,17 +20,16 @@ Contains all metric aggregations, like average aggregation. Metric aggregations
#### agg_req
agg_req contains the users aggregation request. Deserialization from json is compatible with elasticsearch aggregation requests.
#### agg_req_with_accessor
agg_req_with_accessor contains the users aggregation request enriched with fast field accessors etc, which are
#### agg_data
agg_data contains the users aggregation request enriched with fast field accessors etc, which are
used during collection.
#### segment_agg_result
segment_agg_result contains the aggregation result tree, which is used for collection of a segment.
The tree from agg_req_with_accessor is passed during collection.
agg_data is passed during collection.
#### intermediate_agg_result
intermediate_agg_result contains the aggregation tree for merging with other trees.
#### agg_result
agg_result contains the final aggregation tree.

View File

@@ -0,0 +1,104 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::io;
use columnar::{Column, ColumnType};
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
/// Get fast field reader or empty as default.
pub(crate) fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
pub(crate) fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field reader or empty as default.
///
/// Is guaranteed to return at least one column.
pub(crate) fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}

1083
src/aggregation/agg_data.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -70,7 +70,7 @@ impl AggregationLimitsGuard {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*

View File

@@ -26,12 +26,14 @@
//! let _agg_req: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! ```
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -43,7 +45,7 @@ use super::metric::{
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
pub type Aggregations = HashMap<String, Aggregation>;
pub type Aggregations = FxHashMap<String, Aggregation>;
/// Aggregation request.
///
@@ -129,6 +131,9 @@ pub enum AggregationVariants {
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -174,6 +179,7 @@ impl AggregationVariants {
AggregationVariants::Range(range) => vec![range.field.as_str()],
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -208,13 +214,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
match &self {
AggregationVariants::TopHits(top_hits) => Some(top_hits),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -1,471 +0,0 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::collections::HashMap;
use std::io;
use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::VecWithNames;
use crate::aggregation::{f64_to_fastfield_u64, Key};
use crate::index::SegmentReader;
use crate::SegmentOrdinal;
#[derive(Default)]
pub(crate) struct AggregationsWithAccessor {
pub aggs: VecWithNames<AggregationWithAccessor>,
}
impl AggregationsWithAccessor {
fn from_data(aggs: VecWithNames<AggregationWithAccessor>) -> Self {
Self { aggs }
}
pub fn is_empty(&self) -> bool {
self.aggs.is_empty()
}
}
pub struct AggregationWithAccessor {
pub(crate) segment_ordinal: SegmentOrdinal,
/// In general there can be buckets without fast field access, e.g. buckets that are created
/// based on search terms. That is not that case currently, but eventually this needs to be
/// Option or moved.
pub(crate) accessor: Column<u64>,
/// Load insert u64 for missing use case
pub(crate) missing_value_for_accessor: Option<u64>,
pub(crate) str_dict_column: Option<StrColumn>,
pub(crate) field_type: ColumnType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) limits: AggregationLimitsGuard,
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// Used for missing term aggregation, which checks all columns for existence.
/// And also for `top_hits` aggregation, which may sort on multiple fields.
/// By convention the missing aggregation is chosen, when this property is set
/// (instead bein set in `agg`).
/// If this needs to used by other aggregations, we need to refactor this.
// NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type`
// (making them obsolete) But will it have a performance impact?
pub(crate) accessors: Vec<(Column<u64>, ColumnType)>,
/// Map field names to all associated column accessors.
/// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`.
pub(crate) value_accessors: HashMap<String, Vec<DynamicColumn>>,
pub(crate) agg: Aggregation,
}
impl AggregationWithAccessor {
/// May return multiple accessors if the aggregation is e.g. on mixed field types.
fn try_from_agg(
agg: &Aggregation,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: AggregationLimitsGuard,
) -> crate::Result<Vec<AggregationWithAccessor>> {
let mut agg = agg.clone();
let add_agg_with_accessor = |agg: &Aggregation,
accessor: Column<u64>,
column_type: ColumnType,
aggs: &mut Vec<AggregationWithAccessor>|
-> crate::Result<()> {
let res = AggregationWithAccessor {
segment_ordinal,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits: limits.clone(),
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let add_agg_with_accessors = |agg: &Aggregation,
accessors: Vec<(Column<u64>, ColumnType)>,
aggs: &mut Vec<AggregationWithAccessor>,
value_accessors: HashMap<String, Vec<DynamicColumn>>|
-> crate::Result<()> {
let (accessor, field_type) = accessors.first().expect("at least one accessor");
let limits = limits.clone();
let res = AggregationWithAccessor {
segment_ordinal,
// TODO: We should do away with the `accessor` field altogether
accessor: accessor.clone(),
value_accessors,
field_type: *field_type,
accessors,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
limits,
missing_value_for_accessor: None,
str_dict_column: None,
column_block_accessor: Default::default(),
};
aggs.push(res);
Ok(())
};
let mut res: Vec<AggregationWithAccessor> = Vec::new();
use AggregationVariants::*;
match agg.agg {
Range(RangeAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Histogram(HistogramAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
DateHistogram(DateHistogramAggregationReq {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
// Only DateTime is supported for DateHistogram
get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Terms(TermsAggregation {
field: ref field_name,
ref missing,
..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => {
let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
// In case the column is empty we want the shim column to match the missing type
let fallback_type = missing
.as_ref()
.map(|missing| match missing {
Key::Str(_) => ColumnType::Str,
Key::F64(_) => ColumnType::F64,
Key::I64(_) => ColumnType::I64,
Key::U64(_) => ColumnType::U64,
})
.unwrap_or(ColumnType::U64);
let column_and_types = get_all_ff_reader_or_empty(
reader,
field_name,
Some(&allowed_column_types),
fallback_type,
)?;
let missing_and_more_than_one_col = column_and_types.len() > 1 && missing.is_some();
let text_on_non_text_col = column_and_types.len() == 1
&& column_and_types[0].1.numerical_type().is_some()
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
// Actually we could convert the text to a number and have the fast path, if it is
// provided in Rfc3339 format. But this use case is probably common
// enough to justify the effort.
let text_on_date_col = column_and_types.len() == 1
&& column_and_types[0].1 == ColumnType::DateTime
&& missing
.as_ref()
.map(|m| matches!(m, Key::Str(_)))
.unwrap_or(false);
let use_special_missing_agg =
missing_and_more_than_one_col || text_on_non_text_col || text_on_date_col;
if use_special_missing_agg {
let column_and_types =
get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?;
let accessors = column_and_types
.iter()
.map(|c_t| (c_t.0.clone(), c_t.1))
.collect();
add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?;
}
for (accessor, column_type) in column_and_types {
let missing_value_term_agg = if use_special_missing_agg {
None
} else {
missing.clone()
};
let missing_value_for_accessor =
if let Some(missing) = missing_value_term_agg.as_ref() {
get_missing_val_as_u64_lenient(
column_type,
missing,
agg.agg.get_fast_field_names()[0],
)?
} else {
None
};
let limits = limits.clone();
let agg = AggregationWithAccessor {
segment_ordinal,
missing_value_for_accessor,
accessor,
accessors: Default::default(),
value_accessors: Default::default(),
field_type: column_type,
sub_aggregation: get_aggs_with_segment_accessor_and_validate(
sub_aggregation,
reader,
segment_ordinal,
&limits,
)?,
agg: agg.clone(),
str_dict_column: str_dict_column.clone(),
limits,
column_block_accessor: Default::default(),
};
res.push(agg);
}
}
Average(AverageAggregation {
field: ref field_name,
..
})
| Max(MaxAggregation {
field: ref field_name,
..
})
| Min(MinAggregation {
field: ref field_name,
..
})
| Stats(StatsAggregation {
field: ref field_name,
..
})
| ExtendedStats(ExtendedStatsAggregation {
field: ref field_name,
..
})
| Sum(SumAggregation {
field: ref field_name,
..
}) => {
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Count(CountAggregation {
field: ref field_name,
..
}) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let (accessor, column_type) =
get_ff_reader(reader, field_name, Some(&allowed_column_types))?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
Percentiles(ref percentiles) => {
let (accessor, column_type) = get_ff_reader(
reader,
percentiles.field_name(),
Some(get_numeric_or_date_column_types()),
)?;
add_agg_with_accessor(&agg, accessor, column_type, &mut res)?;
}
TopHits(ref mut top_hits) => {
top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?;
let accessors: Vec<(Column<u64>, ColumnType)> = top_hits
.field_names()
.iter()
.map(|field| {
get_ff_reader(reader, field, Some(get_numeric_or_date_column_types()))
})
.collect::<crate::Result<_>>()?;
let value_accessors = top_hits
.value_field_names()
.iter()
.map(|field_name| {
Ok((
field_name.to_string(),
get_dynamic_columns(reader, field_name)?,
))
})
.collect::<crate::Result<_>>()?;
add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?;
}
};
Ok(res)
}
}
/// Get the missing value as internal u64 representation
///
/// For terms we use u64::MAX as sentinel value
/// For numerical data we convert the value into the representation
/// we would get from the fast field, when we open it as u64_lenient_for_type.
///
/// That way we can use it the same way as if it would come from the fastfield.
fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}
// NOTE: We may loose precision of the passed missing value by casting i64 and u64 to f64.
Key::I64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
Key::U64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val as f64, &column_type)
}
_ => {
return Err(crate::TantivyError::InvalidArgument(format!(
"Missing value {missing:?} for field {field_name} is not supported for column \
type {column_type:?}"
)));
}
};
Ok(missing_val)
}
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
pub(crate) fn get_aggs_with_segment_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: &AggregationLimitsGuard,
) -> crate::Result<AggregationsWithAccessor> {
let mut aggss = Vec::new();
for (key, agg) in aggs.iter() {
let aggs = AggregationWithAccessor::try_from_agg(
agg,
agg.sub_aggregation(),
reader,
segment_ordinal,
limits.clone(),
)?;
for agg in aggs {
aggss.push((key.to_string(), agg));
}
}
Ok(AggregationsWithAccessor::from_data(
VecWithNames::from_entries(aggss),
))
}
/// Get fast field reader or empty as default.
fn get_ff_reader(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}
fn get_dynamic_columns(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
let cols = ff_fields
.iter()
.map(|h| h.open())
.collect::<io::Result<_>>()?;
assert!(!ff_fields.is_empty(), "field {field_name} not found");
Ok(cols)
}
/// Get all fast field reader or empty as default.
///
/// Is guaranteed to return at least one column.
fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}

View File

@@ -156,6 +156,8 @@ pub enum BucketResult {
/// The upper bound error for the doc count of each term.
doc_count_error_upper_bound: Option<u64>,
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
}
impl BucketResult {
@@ -172,6 +174,11 @@ impl BucketResult {
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
BucketResult::Filter(filter_result) => {
// Filter doesn't add to bucket count - it's not a user-facing bucket
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
}
}
}
@@ -308,3 +315,25 @@ impl RangeBucketEntry {
1 + self.sub_aggregation.get_bucket_count()
}
}
/// This is the filter bucket result, which contains the document count and sub-aggregations.
///
/// # JSON Format
/// ```json
/// {
/// "electronics_only": {
/// "doc_count": 2,
/// "avg_price": {
/// "value": 150.0
/// }
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FilterBucketResult {
/// Number of documents in the filter bucket
pub doc_count: u64,
/// Sub-aggregation results
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}

View File

@@ -5,7 +5,6 @@ use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::segment_agg_result::AggregationLimitsGuard;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
@@ -128,10 +127,8 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimitsGuard::default(),
);
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();

File diff suppressed because it is too large Load Diff

View File

@@ -1,25 +1,54 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentHistogramCollector to perform the
/// histogram or date_histogram aggregation on a segment.
pub struct HistogramAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
pub is_date_histogram: bool,
/// The bounds to limit the buckets to.
pub bounds: HistogramBounds,
/// The offset used to calculate the bucket position.
pub offset: f64,
}
impl HistogramAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
/// Each document value is rounded down to its bucket.
///
@@ -234,12 +263,12 @@ impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?;
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -256,24 +285,20 @@ pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
column_type: ColumnType,
interval: f64,
offset: f64,
bounds: HistogramBounds,
accessor_idx: usize,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let bucket = self.into_intermediate_bucket_result(agg_with_accessor)?;
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -283,56 +308,52 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let bounds = self.bounds;
let interval = self.interval;
let offset = self.offset;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
bucket_agg_accessor
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
.column_block_accessor
.fetch_block(docs, &bucket_agg_accessor.accessor);
for (doc, val) in bucket_agg_accessor
.column_block_accessor
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
.iter_docid_vals(docs, &req.accessor)
{
let val = self.f64_from_fastfield_u64(val);
let val = f64_from_fastfield_u64(val, &req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = self.sub_aggregation_blueprint.as_mut() {
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
.collect(doc, agg_data)?;
}
}
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
bucket_agg_accessor
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
@@ -340,12 +361,9 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() {
sub_aggregation.flush(sub_aggregation_accessor)?;
sub_aggregation.flush(agg_data)?;
}
Ok(())
@@ -362,65 +380,58 @@ impl SegmentHistogramCollector {
/// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &AggregationWithAccessor,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
&agg_with_accessor.sub_aggregation,
agg_data,
);
buckets.push(bucket_res?);
}
buckets.sort_unstable_by(|b1, b2| b1.key.total_cmp(&b2.key));
let is_date_agg = agg_data
.get_histogram_req_data(self.accessor_idx)
.field_type
== ColumnType::DateTime;
Ok(IntermediateBucketResult::Histogram {
buckets,
is_date_agg: self.column_type == ColumnType::DateTime,
is_date_agg,
})
}
pub(crate) fn from_req_and_validate(
mut req: HistogramAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
field_type: ColumnType,
accessor_idx: usize,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
req.validate()?;
if field_type == ColumnType::DateTime {
req.normalize_date_time();
}
let sub_aggregation_blueprint = if sub_aggregation.is_empty() {
None
let blueprint = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
let sub_aggregation = build_segment_agg_collector(sub_aggregation)?;
Some(sub_aggregation)
None
};
let bounds = req.hard_bounds.unwrap_or(HistogramBounds {
let req_data = agg_data.get_histogram_req_data_mut(node.idx_in_req_data);
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
req_data.sub_aggregation_blueprint = blueprint;
Ok(Self {
buckets: Default::default(),
column_type: field_type,
interval: req.interval,
offset: req.offset.unwrap_or(0.0),
bounds,
sub_aggregations: Default::default(),
sub_aggregation_blueprint,
accessor_idx,
accessor_idx: node.idx_in_req_data,
})
}
#[inline]
fn f64_from_fastfield_u64(&self, val: u64) -> f64 {
f64_from_fastfield_u64(val, &self.column_type)
}
}
#[inline]

View File

@@ -22,6 +22,7 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod filter;
mod histogram;
mod range;
mod term_agg;
@@ -30,6 +31,7 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use filter::*;
pub use histogram::*;
pub use range::*;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

View File

@@ -1,20 +1,43 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
/// Contains all information required by the SegmentRangeCollector to perform the
/// range aggregation on a segment.
pub struct RangeAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
}
impl RangeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Provide user-defined buckets to aggregate on.
///
/// Two special buckets will automatically be created to cover the whole range of values.
@@ -161,12 +184,12 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_with_accessor, &mut sub_aggregation_res)?
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
@@ -184,12 +207,14 @@ impl SegmentRangeBucketEntry {
impl SegmentAggregationCollector for SegmentRangeCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let field_type = self.column_type;
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
@@ -199,7 +224,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(sub_agg)?,
.into_intermediate_bucket_entry(agg_data)?,
))
})
.collect::<crate::Result<_>>()?;
@@ -218,66 +243,70 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
bucket_agg_accessor
.column_block_accessor
.fetch_block(docs, &bucket_agg_accessor.accessor);
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in bucket_agg_accessor
for (doc, val) in req
.column_block_accessor
.iter_docid_vals(docs, &bucket_agg_accessor.accessor)
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.aggs.values[self.accessor_idx].sub_aggregation;
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(sub_aggregation_accessor)?;
sub_agg.flush(agg_data)?;
}
}
Ok(())
}
}
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &mut AggregationsWithAccessor,
limits: &mut AggregationLimitsGuard,
field_type: ColumnType,
accessor_idx: usize,
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
.iter()
.map(|range| {
let key = range
@@ -295,11 +324,7 @@ impl SegmentRangeCollector {
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
let sub_aggregation = if sub_aggregation.is_empty() {
None
} else {
Some(build_segment_agg_collector(sub_aggregation)?)
};
let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
@@ -314,7 +339,7 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
limits.add_memory_consumed(
req_data.context.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
@@ -467,15 +492,45 @@ mod tests {
ranges,
..Default::default()
};
// Build buckets directly as in from_req_and_validate without AggregationsData
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)
.expect("unexpected error in extend_validate_ranges")
.iter()
.map(|range| {
let key = range
.key
.clone()
.map(|key| Ok(Key::Str(key)))
.unwrap_or_else(|| range_to_key(&range.range, &field_type))
.expect("unexpected error in range_to_key");
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
},
}
})
.collect();
SegmentRangeCollector::from_req_and_validate(
&req,
&mut Default::default(),
&mut AggregationLimitsGuard::default(),
field_type,
0,
)
.expect("unexpected error")
SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx: 0,
}
}
#[test]

View File

@@ -0,0 +1,196 @@
use std::fmt::Debug;
use columnar::ColumnType;
use rustc_hash::FxHashMap;
use super::OrderTarget;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::bucket::get_agg_name_and_property;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::TantivyError;
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, u32>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}
impl TermBuckets {
fn get_memory_consumption(&self) -> usize {
let sub_aggs_mem = self.sub_aggs.memory_consumption();
let buckets_mem = self.entries.memory_consumption();
sub_aggs_mem + buckets_mem
}
fn force_flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentTermCollector {
/// The buckets containing the aggregation data.
term_buckets: TermBuckets,
accessor_idx: usize,
}
impl SegmentAggregationCollector for SegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let entries: Vec<(u64, u32)> = self.term_buckets.entries.into_iter().collect();
let bucket = super::into_intermediate_bucket_result(
self.accessor_idx,
entries,
self.term_buckets.sub_aggs,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = req_data.sub_aggregation_blueprint.as_ref() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(doc, agg_data)?;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.term_buckets.force_flush(agg_data)?;
Ok(())
}
}
impl SegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let column_type = terms_req_data.column_type;
let accessor_idx = node.idx_in_req_data;
if column_type == ColumnType::Bytes {
return Err(TantivyError::InvalidArgument(format!(
"terms aggregation is not supported for column type {column_type:?}"
)));
}
let term_buckets = TermBuckets::default();
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(SegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}

View File

@@ -0,0 +1,228 @@
use std::vec;
use rustc_hash::FxHashMap;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::{get_agg_name_and_property, OrderTarget};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::{DocId, TantivyError};
const MAX_BATCH_SIZE: usize = 1_024;
#[derive(Debug, Clone)]
struct LowCardTermBuckets {
entries: Box<[u32]>,
sub_aggs: Vec<Box<dyn SegmentAggregationCollector>>,
doc_buffers: Box<[Vec<DocId>]>,
}
impl LowCardTermBuckets {
pub fn with_num_buckets(
num_buckets: usize,
sub_aggs_blueprint_opt: Option<&Box<dyn SegmentAggregationCollector>>,
) -> Self {
let sub_aggs = sub_aggs_blueprint_opt
.as_ref()
.map(|blueprint| {
std::iter::repeat_with(|| blueprint.clone_box())
.take(num_buckets)
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self {
entries: vec![0; num_buckets].into_boxed_slice(),
sub_aggs,
doc_buffers: std::iter::repeat_with(|| Vec::with_capacity(MAX_BATCH_SIZE))
.take(num_buckets)
.collect::<Vec<_>>()
.into_boxed_slice(),
}
}
fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.entries.len() * std::mem::size_of::<u32>()
+ self.doc_buffers.len()
* (std::mem::size_of::<Vec<DocId>>()
+ std::mem::size_of::<DocId>() * MAX_BATCH_SIZE)
}
}
#[derive(Debug, Clone)]
pub struct LowCardSegmentTermCollector {
term_buckets: LowCardTermBuckets,
accessor_idx: usize,
}
impl LowCardSegmentTermCollector {
pub fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let terms_req_data = req_data.get_term_req_data(node.idx_in_req_data);
let accessor_idx = node.idx_in_req_data;
let cardinality = terms_req_data
.accessor
.max_value()
.max(terms_req_data.missing_value_for_accessor.unwrap_or(0))
+ 1;
assert!(cardinality <= super::LOW_CARDINALITY_THRESHOLD);
// Validate sub aggregation exists
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
let has_sub_aggregations = !node.children.is_empty();
let blueprint = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let terms_req_data = req_data.get_term_req_data_mut(node.idx_in_req_data);
let term_buckets =
LowCardTermBuckets::with_num_buckets(cardinality as usize, blueprint.as_ref());
terms_req_data.sub_aggregation_blueprint = blueprint;
Ok(LowCardSegmentTermCollector {
term_buckets,
accessor_idx,
})
}
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let term_buckets_mem = self.term_buckets.get_memory_consumption();
self_mem + term_buckets_mem
}
}
impl SegmentAggregationCollector for LowCardSegmentTermCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_data.get_term_req_data(self.accessor_idx).name.clone();
let sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>> = self
.term_buckets
.sub_aggs
.into_iter()
.enumerate()
.filter(|(bucket_id, _sub_agg)| self.term_buckets.entries[*bucket_id] > 0)
.map(|(bucket_id, sub_agg)| (bucket_id as u64, sub_agg))
.collect();
let entries: Vec<(u64, u32)> = self
.term_buckets
.entries
.iter()
.enumerate()
.filter(|(_, count)| **count > 0)
.map(|(bucket_id, count)| (bucket_id as u64, *count))
.collect();
let bucket =
super::into_intermediate_bucket_result(self.accessor_idx, entries, sub_aggs, agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.len() > MAX_BATCH_SIZE {
for batch in docs.chunks(MAX_BATCH_SIZE) {
self.collect_block(batch, agg_data)?;
}
}
let mut req_data = agg_data.take_term_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
if let Some(missing) = req_data.missing_value_for_accessor {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
// has subagg
if req_data.sub_aggregation_blueprint.is_some() {
for (doc, term_id) in req_data
.column_block_accessor
.iter_docid_vals(docs, &req_data.accessor)
{
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.doc_buffers[term_id as usize].push(doc);
}
for (bucket_id, docs) in self.term_buckets.doc_buffers.iter_mut().enumerate() {
self.term_buckets.entries[bucket_id] += docs.len() as u32;
self.term_buckets.sub_aggs[bucket_id].collect_block(&docs[..], agg_data)?;
docs.clear();
}
} else {
for term_id in req_data.column_block_accessor.iter_vals() {
if let Some(allowed_bs) = req_data.allowed_term_ids.as_ref() {
if !allowed_bs.contains(term_id as u32) {
continue;
}
}
self.term_buckets.entries[term_id as usize] += 1;
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
agg_data.put_back_term_req_data(self.accessor_idx, req_data);
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregations in &mut self.term_buckets.sub_aggs.iter_mut() {
sub_aggregations.as_mut().flush(agg_data)?;
}
Ok(())
}
}

View File

@@ -1,13 +1,39 @@
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
///
/// This is needed when:
/// - The field is multi-valued and we therefore have multiple columns
/// - The field is not text and missing is provided as string (we cannot use the numeric missing
/// value optimization)
#[derive(Default)]
pub struct MissingTermAggReqData {
/// The accessors to check for existence of a value.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The name of the aggregation.
pub name: String,
/// The original terms aggregation request.
pub req: TermsAggregation,
}
impl MissingTermAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
@@ -18,12 +44,13 @@ pub struct TermMissingAgg {
}
impl TermMissingAgg {
pub(crate) fn new(
accessor_idx: usize,
sub_aggregations: &mut AggregationsWithAccessor,
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !sub_aggregations.is_empty();
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collector(sub_aggregations)?;
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
Some(sub_aggregation)
} else {
None
@@ -40,16 +67,11 @@ impl TermMissingAgg {
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let term_agg = agg_with_accessor
.agg
.agg
.as_term()
.expect("TermMissingAgg collector must be term agg req");
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
.missing
.as_ref()
@@ -64,10 +86,7 @@ impl SegmentAggregationCollector for TermMissingAgg {
};
if let Some(sub_agg) = self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(
&agg_with_accessor.sub_aggregation,
&mut res,
)?;
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -80,7 +99,10 @@ impl SegmentAggregationCollector for TermMissingAgg {
},
};
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Bucket(bucket),
)?;
Ok(())
}
@@ -88,17 +110,17 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx];
let has_value = agg
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, &mut agg.sub_aggregation)?;
sub_agg.collect(doc, agg_data)?;
}
}
Ok(())
@@ -107,10 +129,10 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_with_accessor)?;
self.collect(*doc, agg_data)?;
}
Ok(())
}

View File

@@ -1,6 +1,6 @@
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
@@ -37,23 +37,23 @@ impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_with_accessor, results)
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
@@ -63,20 +63,20 @@ impl SegmentAggregationCollector for BufAggregationCollector {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_with_accessor)?;
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_with_accessor)?;
self.collector.flush(agg_data)?;
Ok(())
}

View File

@@ -1,12 +1,12 @@
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimitsGuard, SegmentAggregationCollector,
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::index::SegmentReader;
use crate::{DocId, SegmentOrdinal, TantivyError};
@@ -22,7 +22,7 @@ pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
limits: AggregationLimitsGuard,
context: AggContextParams,
}
impl AggregationCollector {
@@ -30,8 +30,8 @@ impl AggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
}
}
@@ -45,7 +45,7 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
limits: AggregationLimitsGuard,
context: AggContextParams,
}
impl DistributedAggregationCollector {
@@ -53,8 +53,8 @@ impl DistributedAggregationCollector {
///
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimitsGuard) -> Self {
Self { agg, limits }
pub fn from_aggs(agg: Aggregations, context: AggContextParams) -> Self {
Self { agg, context }
}
}
@@ -72,7 +72,7 @@ impl Collector for DistributedAggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.limits,
&self.context,
)
}
@@ -102,7 +102,7 @@ impl Collector for AggregationCollector {
&self.agg,
reader,
segment_local_id,
&self.limits,
&self.context,
)
}
@@ -115,7 +115,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_result(self.agg.clone(), self.limits.clone())
res.into_final_result(self.agg.clone(), self.context.limits.clone())
}
}
@@ -135,7 +135,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsWithAccessor,
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -147,14 +147,15 @@ impl AggregationSegmentCollector {
agg: &Aggregations,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
limits: &AggregationLimitsGuard,
context: &AggContextParams,
) -> crate::Result<Self> {
let mut aggs_with_accessor =
get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?;
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?);
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor,
aggs_with_accessor: agg_data,
agg_collector: result,
error: None,
})

View File

@@ -24,7 +24,9 @@ use super::metric::{
};
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -179,12 +181,17 @@ impl IntermediateAggregationResults {
}
/// Merge another intermediate aggregation result into this result.
///
/// The order of the values need to be the same on both results. This is ensured when the same
/// (key values) are present on the underlying `VecWithNames` struct.
pub fn merge_fruits(&mut self, other: IntermediateAggregationResults) -> crate::Result<()> {
for (left, right) in self.aggs_res.values_mut().zip(other.aggs_res.into_values()) {
left.merge_fruits(right)?;
pub fn merge_fruits(&mut self, mut other: IntermediateAggregationResults) -> crate::Result<()> {
for (key, left) in self.aggs_res.iter_mut() {
if let Some(key) = other.aggs_res.remove(key) {
left.merge_fruits(key)?;
}
}
// Move remainder of other aggs_res into self.
// Note: Currently we don't expect this to happen, as we create empty intermediate results
// via [IntermediateAggregationResults::empty_from_req].
for (key, value) in other.aggs_res {
self.aggs_res.insert(key, value);
}
Ok(())
}
@@ -241,11 +248,16 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
Filter(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Filter {
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
}
}
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum IntermediateAggregationResult {
/// Bucket variant
Bucket(IntermediateBucketResult),
@@ -426,6 +438,13 @@ pub enum IntermediateBucketResult {
/// The term buckets
buckets: IntermediateTermBucketResult,
},
/// Filter aggregation - a single bucket with sub-aggregations
Filter {
/// Document count in the filter bucket
doc_count: u64,
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
}
impl IntermediateBucketResult {
@@ -509,6 +528,18 @@ impl IntermediateBucketResult {
req.sub_aggregation(),
limits,
),
IntermediateBucketResult::Filter {
doc_count,
sub_aggregations,
} => {
// Convert sub-aggregation results to final format
let final_sub_aggregations = sub_aggregations
.into_final_result(req.sub_aggregation().clone(), limits.clone())?;
Ok(BucketResult::Filter(FilterBucketResult {
doc_count,
sub_aggregations: final_sub_aggregations,
}))
}
}
}
@@ -562,6 +593,19 @@ impl IntermediateBucketResult {
*buckets_left = buckets?;
}
(
IntermediateBucketResult::Filter {
doc_count: doc_count_left,
sub_aggregations: sub_aggs_left,
},
IntermediateBucketResult::Filter {
doc_count: doc_count_right,
sub_aggregations: sub_aggs_right,
},
) => {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -571,6 +615,9 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Terms { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}

View File

@@ -2,15 +2,13 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::Dictionary;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
@@ -97,6 +95,32 @@ pub struct CardinalityAggregationReq {
pub missing: Option<Key>,
}
/// Contains all information required by the SegmentCardinalityCollector to perform the
/// cardinality aggregation on a segment.
pub struct CardinalityAggReqData {
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// The column_type of the field.
pub column_type: ColumnType,
/// The string dictionary column if the field is of type string.
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
pub req: CardinalityAggregationReq,
}
impl CardinalityAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
@@ -115,47 +139,44 @@ impl CardinalityAggregationReq {
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
accessor_idx: usize,
missing: Option<Key>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
&agg_data.accessor,
missing,
);
} else {
agg_accessor
agg_data
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_with_accessor: &AggregationWithAccessor,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateMetricResult> {
if self.column_type == ColumnType::Str {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = agg_with_accessor
let dict = req_data
.str_dict_column
.as_ref()
.map(|el| el.dictionary())
@@ -180,10 +201,10 @@ impl SegmentCardinalityCollector {
})?;
if has_missing {
// Replace missing with the actual value provided
let missing_key = self
.missing
.as_ref()
.expect("Found sentinel value u64::MAX for term_ord but `missing` is not set");
let missing_key =
req_data.req.missing.as_ref().expect(
"Found sentinel value u64::MAX for term_ord but `missing` is not set",
);
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
@@ -209,13 +230,13 @@ impl SegmentCardinalityCollector {
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -227,26 +248,26 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
if self.column_type == ColumnType::Str {
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = bucket_agg_accessor
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
.accessor
.values
.clone()

View File

@@ -4,12 +4,11 @@ use std::mem;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
@@ -348,20 +347,20 @@ impl SegmentExtendedStatsCollector {
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
&req_data.accessor,
*missing,
);
} else {
agg_accessor
req_data
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
.fetch_block(docs, &req_data.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
@@ -372,10 +371,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
@@ -390,12 +389,12 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in field.values_for_doc(doc) {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
@@ -405,7 +404,7 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in field.values_for_doc(doc) {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
@@ -418,10 +417,10 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -31,6 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -44,6 +45,35 @@ pub use top_hits::*;
use crate::schema::OwnedValue;
/// Contains all information required by metric aggregations like avg, min, max, sum, stats,
/// extended_stats, count, percentiles.
#[repr(C)]
pub struct MetricAggReqData {
/// True if the field is of number or date type.
pub is_number_or_date_type: bool,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result
pub collecting_for: StatsType,
/// The missing value
pub missing: Option<f64>,
/// The name of the aggregation.
pub name: String,
}
impl MetricAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// Single-metric aggregations use this common result structure.
///
/// Main reason to wrap it in value is to match elasticsearch output structure.

View File

@@ -3,12 +3,11 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
@@ -112,7 +111,8 @@ impl PercentilesAggregationReq {
&self.field
}
fn validate(&self) -> crate::Result<()> {
/// Validates the request parameters.
pub fn validate(&self) -> crate::Result<()> {
if let Some(percents) = self.percents.as_ref() {
let all_in_range = percents
.iter()
@@ -133,10 +133,8 @@ impl PercentilesAggregationReq {
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentPercentilesCollector {
field_type: ColumnType,
pub(crate) percentiles: PercentilesCollector,
pub(crate) accessor_idx: usize,
missing: Option<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -231,43 +229,32 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(
req: &PercentilesAggregationReq,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
req.validate()?;
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
field_type,
percentiles: PercentilesCollector::new(),
accessor_idx,
missing,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
&req_data.accessor,
*missing,
);
} else {
agg_accessor
req_data
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
.fetch_block(docs, &req_data.accessor);
}
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
@@ -277,10 +264,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
results.push(
@@ -295,24 +282,24 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &self.field_type));
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
@@ -324,10 +311,10 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -3,12 +3,11 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
@@ -166,74 +165,65 @@ impl IntermediateStats {
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentStatsType {
/// The type of stats aggregation to perform.
/// Note that not all stats types are supported in the stats aggregation.
#[derive(Clone, Copy, Debug)]
pub enum StatsType {
/// The average of the values.
Average,
/// The count of the values.
Count,
/// The maximum value.
Max,
/// The minimum value.
Min,
/// The stats (count, sum, min, max, avg) of the values.
Stats,
/// The extended stats (count, sum, min, max, avg, sum_of_squares, variance, std_deviation,
ExtendedStats(Option<f64>), // sigma
/// The sum of the values.
Sum,
/// The percentiles of the values.
Percentiles,
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
missing: Option<u64>,
field_type: ColumnType,
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentStatsCollector {
pub fn from_req(
field_type: ColumnType,
collecting_for: SegmentStatsType,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req(accessor_idx: usize) -> Self {
Self {
field_type,
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut AggregationWithAccessor,
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
agg_accessor.column_block_accessor.fetch_block_with_missing(
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
&req_data.accessor,
*missing,
);
} else {
agg_accessor
req_data
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
.fetch_block(docs, &req_data.accessor);
}
if [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::DateTime,
]
.contains(&self.field_type)
{
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in agg_accessor.column_block_accessor.iter_vals() {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
@@ -245,27 +235,28 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
let intermediate_metric_result = match self.collecting_for {
SegmentStatsType::Average => {
let intermediate_metric_result = match req.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
}
SegmentStatsType::Count => {
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
}
SegmentStatsType::Max => {
IntermediateMetricResult::Max(IntermediateMax::from_collector(*self))
}
SegmentStatsType::Min => {
IntermediateMetricResult::Min(IntermediateMin::from_collector(*self))
}
SegmentStatsType::Stats => IntermediateMetricResult::Stats(self.stats),
SegmentStatsType::Sum => {
IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self))
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
req.collecting_for
)))
}
};
@@ -281,23 +272,23 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &agg_with_accessor.aggs.values[self.accessor_idx].accessor;
if let Some(missing) = self.missing {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
@@ -309,10 +300,10 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
Ok(())
}
}

View File

@@ -9,6 +9,7 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::{TopHitsMetricResult, TopHitsVecEntry};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::Order;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
@@ -18,6 +19,30 @@ use crate::aggregation::AggregationError;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
#[derive(Default)]
pub struct TopHitsAggReqData {
/// The accessors to access the fast field values.
pub accessors: Vec<(Column<u64>, ColumnType)>,
/// The accessors to access the fast field values for retrieving document fields.
pub value_accessors: HashMap<String, Vec<DynamicColumn>>,
/// The ordinal of the segment this request data is for.
pub segment_ordinal: SegmentOrdinal,
/// The name of the aggregation.
pub name: String,
/// The top_hits aggregation request.
pub req: TopHitsAggregationReq,
}
impl TopHitsAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
}
}
/// # Top Hits
///
@@ -566,23 +591,18 @@ impl TopHitsSegmentCollector {
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors;
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, tophits_req),
self.into_top_hits_collector(value_accessors, &req_data.req),
);
results.push(
name,
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
)
}
@@ -591,32 +611,22 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn collect(
&mut self,
doc_id: crate::DocId,
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
self.collect_with(doc_id, tophits_req, accessors)?;
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let tophits_req = &agg_with_accessor.aggs.values[self.accessor_idx]
.agg
.agg
.as_top_hits()
.expect("aggregation request must be of type top hits");
let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors;
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, tophits_req, accessors)?;
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
}
Ok(())
}

View File

@@ -127,9 +127,10 @@
//! [`AggregationResults`](agg_result::AggregationResults) via the
//! [`into_final_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_result) method.
mod accessor_helpers;
mod agg_data;
mod agg_limits;
pub mod agg_req;
mod agg_req_with_accessor;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
@@ -140,7 +141,6 @@ pub mod intermediate_agg_result;
pub mod metric;
mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
#[cfg(test)]
@@ -160,6 +160,28 @@ use itertools::Itertools;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
/// - `limits`: Memory and bucket limits for the aggregation
/// - `tokenizers`: TokenizerManager for parsing query strings in filter aggregations
#[derive(Clone, Default)]
pub struct AggContextParams {
/// Aggregation limits (memory and bucket count)
pub limits: AggregationLimitsGuard,
/// Tokenizer manager for query string parsing
pub tokenizers: TokenizerManager,
}
impl AggContextParams {
/// Create new aggregation context parameters
pub fn new(limits: AggregationLimitsGuard, tokenizers: TokenizerManager) -> Self {
Self { limits, tokenizers }
}
}
fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
let parsed = value
.parse::<f64>()
@@ -257,80 +279,6 @@ where D: Deserializer<'de> {
deserializer.deserialize_any(StringOrFloatVisitor)
}
/// Represents an associative array `(key => values)` in a very efficient manner.
#[derive(PartialEq, Serialize, Deserialize)]
pub(crate) struct VecWithNames<T> {
pub(crate) values: Vec<T>,
keys: Vec<String>,
}
impl<T: Clone> Clone for VecWithNames<T> {
fn clone(&self) -> Self {
Self {
values: self.values.clone(),
keys: self.keys.clone(),
}
}
}
impl<T> Default for VecWithNames<T> {
fn default() -> Self {
Self {
values: Default::default(),
keys: Default::default(),
}
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for VecWithNames<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_map().entries(self.iter()).finish()
}
}
impl<T> From<HashMap<String, T>> for VecWithNames<T> {
fn from(map: HashMap<String, T>) -> Self {
VecWithNames::from_entries(map.into_iter().collect_vec())
}
}
impl<T> VecWithNames<T> {
fn from_entries(mut entries: Vec<(String, T)>) -> Self {
// Sort to ensure order of elements match across multiple instances
entries.sort_by(|left, right| left.0.cmp(&right.0));
let mut data = Vec::with_capacity(entries.len());
let mut data_names = Vec::with_capacity(entries.len());
for entry in entries {
data_names.push(entry.0);
data.push(entry.1);
}
VecWithNames {
values: data,
keys: data_names,
}
}
fn iter(&self) -> impl Iterator<Item = (&str, &T)> + '_ {
self.keys().zip(self.values.iter())
}
fn keys(&self) -> impl Iterator<Item = &str> + '_ {
self.keys.iter().map(|key| key.as_str())
}
fn values_mut(&mut self) -> impl Iterator<Item = &mut T> + '_ {
self.values.iter_mut()
}
fn is_empty(&self) -> bool {
self.keys.is_empty()
}
fn len(&self) -> usize {
self.keys.len()
}
fn get(&self, name: &str) -> Option<&T> {
self.keys()
.position(|key| key == name)
.map(|pos| &self.values[pos])
}
}
/// The serialized key is used in a `HashMap`.
pub type SerializedKey = String;
@@ -464,7 +412,10 @@ mod tests {
query: Option<(&str, &str)>,
limits: AggregationLimitsGuard,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, limits);
let collector = AggregationCollector::from_aggs(
agg_req,
AggContextParams::new(limits, index.tokenizers().clone()),
);
let reader = index.reader()?;
let searcher = reader.searcher();

View File

@@ -6,48 +6,41 @@
use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::agg_req::AggregationVariants;
use super::agg_req_with_accessor::{AggregationWithAccessor, AggregationsWithAccessor};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
SegmentPercentilesCollector, SegmentStatsCollector, SegmentStatsType, StatsAggregation,
SumAggregation,
};
use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()>;
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()>;
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
/// This method ensures those staged docs will be collected.
fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
}
pub(crate) trait CollectorClone {
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
@@ -65,119 +58,6 @@ impl Clone for Box<dyn SegmentAggregationCollector> {
}
}
pub(crate) fn build_segment_agg_collector(
req: &mut AggregationsWithAccessor,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
// Single collector special case
if req.aggs.len() == 1 {
let req = &mut req.aggs.values[0];
let accessor_idx = 0;
return build_single_agg_segment_collector(req, accessor_idx);
}
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
Ok(Box::new(agg))
}
pub(crate) fn build_single_agg_segment_collector(
req: &mut AggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
use AggregationVariants::*;
match &req.agg.agg {
Terms(terms_req) => {
if req.accessors.is_empty() {
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
} else {
Ok(Box::new(TermMissingAgg::new(
accessor_idx,
&mut req.sub_aggregation,
)?))
}
}
Range(range_req) => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&mut req.sub_aggregation,
&mut req.limits,
req.field_type,
accessor_idx,
)?)),
Histogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.clone(),
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
DateHistogram(histogram) => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram.to_histogram_req()?,
&mut req.sub_aggregation,
req.field_type,
accessor_idx,
)?)),
Average(AverageAggregation { missing, .. }) => {
Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Average,
accessor_idx,
*missing,
)))
}
Count(CountAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Count,
accessor_idx,
*missing,
))),
Max(MaxAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Max,
accessor_idx,
*missing,
))),
Min(MinAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Min,
accessor_idx,
*missing,
))),
Stats(StatsAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Stats,
accessor_idx,
*missing,
))),
ExtendedStats(ExtendedStatsAggregation { missing, sigma, .. }) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req.field_type, *sigma, accessor_idx, *missing),
)),
Sum(SumAggregation { missing, .. }) => Ok(Box::new(SegmentStatsCollector::from_req(
req.field_type,
SegmentStatsType::Sum,
accessor_idx,
*missing,
))),
Percentiles(percentiles_req) => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
percentiles_req,
req.field_type,
accessor_idx,
)?,
)),
TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req(
top_hits_req,
accessor_idx,
req.segment_ordinal,
))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
@@ -197,11 +77,11 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_with_accessor, results)?;
agg.add_intermediate_aggregation_result(agg_data, results)?;
}
Ok(())
@@ -210,9 +90,9 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)?;
self.collect_block(&[doc], agg_data)?;
Ok(())
}
@@ -220,32 +100,19 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect_block(docs, agg_with_accessor)?;
collector.collect_block(docs, agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.flush(agg_with_accessor)?;
collector.flush(agg_data)?;
}
Ok(())
}
}
impl GenericSegmentAggregationResultsCollector {
pub(crate) fn from_req_and_validate(req: &mut AggregationsWithAccessor) -> crate::Result<Self> {
let aggs = req
.aggs
.values_mut()
.enumerate()
.map(|(accessor_idx, req)| build_single_agg_segment_collector(req, accessor_idx))
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
Ok(GenericSegmentAggregationResultsCollector { aggs })
}
}

View File

@@ -87,6 +87,17 @@ pub trait DocSet: Send {
/// length of the docset.
fn size_hint(&self) -> u32;
/// Returns a best-effort hint of the cost to consume the entire docset.
///
/// Consuming means calling advance until [`TERMINATED`] is returned.
/// The cost should be relative to the cost of driving a Term query,
/// which would be the number of documents in the DocSet.
///
/// By default this returns `size_hint()`.
fn cost(&self) -> u64 {
self.size_hint() as u64
}
/// Returns the number documents matching.
/// Calling this method consumes the `DocSet`.
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -134,6 +145,10 @@ impl DocSet for &mut dyn DocSet {
(**self).size_hint()
}
fn cost(&self) -> u64 {
(**self).cost()
}
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
(**self).count(alive_bitset)
}
@@ -169,6 +184,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.size_hint()
}
fn cost(&self) -> u64 {
let unboxed: &TDocSet = self.borrow();
unboxed.cost()
}
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.count(alive_bitset)

View File

@@ -41,8 +41,6 @@ const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
#[cfg(test)]
pub(crate) mod tests {
use std::iter;
use proptest::prelude::*;
use proptest::sample::select;

View File

@@ -667,12 +667,15 @@ mod bench {
.read_postings(&TERM_D, IndexRecordOption::Basic)
.unwrap()
.unwrap();
let mut intersection = Intersection::new(vec![
segment_postings_a,
segment_postings_b,
segment_postings_c,
segment_postings_d,
]);
let mut intersection = Intersection::new(
vec![
segment_postings_a,
segment_postings_b,
segment_postings_c,
segment_postings_d,
],
reader.searcher().num_docs() as u32,
);
while intersection.advance() != TERMINATED {}
});
}

View File

@@ -367,10 +367,14 @@ mod tests {
checkpoints
}
fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
fn compute_checkpoints_manual(
term_scorers: Vec<TermScorer>,
n: usize,
max_doc: u32,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default);
let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default, max_doc);
let mut limit = Score::MIN;
loop {
@@ -478,7 +482,8 @@ mod tests {
for top_k in 1..4 {
let checkpoints_for_each_pruning =
compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
let checkpoints_manual = compute_checkpoints_manual(term_scorers.clone(), top_k);
let checkpoints_manual =
compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
.iter()

View File

@@ -39,9 +39,11 @@ where
))
}
/// num_docs is the number of documents in the segment.
fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
num_docs: u32,
) -> SpecializedScorer
where
TScoreCombiner: ScoreCombiner,
@@ -68,6 +70,7 @@ where
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
score_combiner_fn,
num_docs,
)));
}
}
@@ -75,16 +78,19 @@ where
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
score_combiner_fn,
num_docs,
)))
}
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
scorer: SpecializedScorer,
score_combiner_fn: impl Fn() -> TScoreCombiner,
num_docs: u32,
) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn);
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
SpecializedScorer::Other(scorer) => scorer,
@@ -151,6 +157,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
let num_docs = reader.num_docs();
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
// Indicate how should clauses are combined with other clauses.
enum CombinationMethod {
@@ -167,11 +174,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
1 => {
let scorer_union = scorer_union(should_scorers, &score_combiner_fn);
CombinationMethod::Required(scorer_union)
}
0 => CombinationMethod::Optional(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
)),
1 => CombinationMethod::Required(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
)),
n if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
// they are no different from must clauses.
@@ -200,21 +212,21 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
.map(|specialized_scorer: SpecializedScorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
});
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => {
SpecializedScorer::Other(intersect_scorers(must_scorers))
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
}
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers);
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
),
))
} else {
@@ -222,8 +234,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
}
}
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn));
SpecializedScorer::Other(intersect_scorers(must_scorers))
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
}
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
@@ -233,7 +245,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
let positive_scorer_boxed =
into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
exclude_scorer,
@@ -246,6 +259,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs();
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
} else if self.weights.len() == 1 {
@@ -258,12 +272,12 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
} else if self.scoring_enabled {
self.complex_scorer(reader, boost, &self.score_combiner_fn)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, &self.score_combiner_fn)
into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
})
} else {
self.complex_scorer(reader, boost, DoNothingCombiner::default)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
})
}
}
@@ -296,8 +310,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
@@ -317,8 +334,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {

View File

@@ -117,6 +117,10 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
self.underlying.size_hint()
}
fn cost(&self) -> u64 {
self.underlying.cost()
}
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
self.underlying.count(alive_bitset)
}

View File

@@ -130,6 +130,10 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
fn cost(&self) -> u64 {
self.docset.cost()
}
}
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {

View File

@@ -70,6 +70,10 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
fn size_hint(&self) -> u32 {
self.scorer.size_hint()
}
fn cost(&self) -> u64 {
self.scorer.cost()
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
@@ -146,6 +150,14 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
.max()
.unwrap_or(0u32)
}
fn cost(&self) -> u64 {
self.chains
.iter()
.map(|docset| docset.cost())
.max()
.unwrap_or(0u64)
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer

View File

@@ -1,4 +1,5 @@
use crate::docset::{DocSet, TERMINATED};
use crate::query::size_hint::estimate_intersection;
use crate::query::term_query::TermScorer;
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
@@ -11,14 +12,18 @@ use crate::{DocId, Score};
/// For better performance, the function uses a
/// specialized implementation if the two
/// shortest scorers are `TermScorer`s.
pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
pub fn intersect_scorers(
mut scorers: Vec<Box<dyn Scorer>>,
num_docs_segment: u32,
) -> Box<dyn Scorer> {
if scorers.is_empty() {
return Box::new(EmptyScorer);
}
if scorers.len() == 1 {
return scorers.pop().unwrap();
}
scorers.sort_by_key(|scorer| scorer.size_hint());
// Order by estimated cost to drive each scorer.
scorers.sort_by_key(|scorer| scorer.cost());
let doc = go_to_first_doc(&mut scorers[..]);
if doc == TERMINATED {
return Box::new(EmptyScorer);
@@ -34,12 +39,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers,
num_docs: num_docs_segment,
});
}
Box::new(Intersection {
left,
right,
others: scorers,
num_docs: num_docs_segment,
})
}
@@ -48,6 +55,7 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
left: TDocSet,
right: TDocSet,
others: Vec<TOtherDocSet>,
num_docs: u32,
}
fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
@@ -66,10 +74,11 @@ fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
}
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
/// num_docs is the number of documents in the segment.
pub(crate) fn new(mut docsets: Vec<TDocSet>, num_docs: u32) -> Intersection<TDocSet, TDocSet> {
let num_docsets = docsets.len();
assert!(num_docsets >= 2);
docsets.sort_by_key(|docset| docset.size_hint());
docsets.sort_by_key(|docset| docset.cost());
go_to_first_doc(&mut docsets);
let left = docsets.remove(0);
let right = docsets.remove(0);
@@ -77,6 +86,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
left,
right,
others: docsets,
num_docs,
}
}
}
@@ -141,7 +151,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
}
fn size_hint(&self) -> u32 {
self.left.size_hint()
estimate_intersection(
[self.left.size_hint(), self.right.size_hint()]
.into_iter()
.chain(self.others.iter().map(DocSet::size_hint)),
self.num_docs,
)
}
fn cost(&self) -> u64 {
// What's the best way to compute the cost of an intersection?
// For now we take the cost of the docset driver, which is the first docset.
// If there are docsets that are bad at skipping, they should also influence the cost.
self.left.cost()
}
}
@@ -169,7 +191,7 @@ mod tests {
{
let left = VecDocSet::from(vec![1, 3, 9]);
let right = VecDocSet::from(vec![3, 4, 9, 18]);
let mut intersection = Intersection::new(vec![left, right]);
let mut intersection = Intersection::new(vec![left, right], 10);
assert_eq!(intersection.doc(), 3);
assert_eq!(intersection.advance(), 9);
assert_eq!(intersection.doc(), 9);
@@ -179,7 +201,7 @@ mod tests {
let a = VecDocSet::from(vec![1, 3, 9]);
let b = VecDocSet::from(vec![3, 4, 9, 18]);
let c = VecDocSet::from(vec![1, 5, 9, 111]);
let mut intersection = Intersection::new(vec![a, b, c]);
let mut intersection = Intersection::new(vec![a, b, c], 10);
assert_eq!(intersection.doc(), 9);
assert_eq!(intersection.advance(), TERMINATED);
}
@@ -189,7 +211,7 @@ mod tests {
fn test_intersection_zero() {
let left = VecDocSet::from(vec![0]);
let right = VecDocSet::from(vec![0]);
let mut intersection = Intersection::new(vec![left, right]);
let mut intersection = Intersection::new(vec![left, right], 10);
assert_eq!(intersection.doc(), 0);
assert_eq!(intersection.advance(), TERMINATED);
}
@@ -198,7 +220,7 @@ mod tests {
fn test_intersection_skip() {
let left = VecDocSet::from(vec![0, 1, 2, 4]);
let right = VecDocSet::from(vec![2, 5]);
let mut intersection = Intersection::new(vec![left, right]);
let mut intersection = Intersection::new(vec![left, right], 10);
assert_eq!(intersection.seek(2), 2);
assert_eq!(intersection.doc(), 2);
}
@@ -209,7 +231,7 @@ mod tests {
|| {
let left = VecDocSet::from(vec![4]);
let right = VecDocSet::from(vec![2, 5]);
Box::new(Intersection::new(vec![left, right]))
Box::new(Intersection::new(vec![left, right], 10))
},
vec![0, 2, 4, 5, 6],
);
@@ -219,19 +241,22 @@ mod tests {
let mut right = VecDocSet::from(vec![2, 5, 10]);
left.advance();
right.advance();
Box::new(Intersection::new(vec![left, right]))
Box::new(Intersection::new(vec![left, right], 10))
},
vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
);
test_skip_against_unoptimized(
|| {
Box::new(Intersection::new(vec![
VecDocSet::from(vec![1, 4, 5, 6]),
VecDocSet::from(vec![1, 2, 5, 6]),
VecDocSet::from(vec![1, 4, 5, 6]),
VecDocSet::from(vec![1, 5, 6]),
VecDocSet::from(vec![2, 4, 5, 7, 8]),
]))
Box::new(Intersection::new(
vec![
VecDocSet::from(vec![1, 4, 5, 6]),
VecDocSet::from(vec![1, 2, 5, 6]),
VecDocSet::from(vec![1, 4, 5, 6]),
VecDocSet::from(vec![1, 5, 6]),
VecDocSet::from(vec![2, 4, 5, 7, 8]),
],
10,
))
},
vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
);
@@ -242,7 +267,7 @@ mod tests {
let a = VecDocSet::from(vec![1, 3]);
let b = VecDocSet::from(vec![1, 4]);
let c = VecDocSet::from(vec![3, 9]);
let intersection = Intersection::new(vec![a, b, c]);
let intersection = Intersection::new(vec![a, b, c], 10);
assert_eq!(intersection.doc(), TERMINATED);
}
}

View File

@@ -23,6 +23,7 @@ mod regex_query;
mod reqopt_scorer;
mod scorer;
mod set_query;
mod size_hint;
mod term_query;
mod union;
mod weight;

View File

@@ -200,6 +200,10 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
fn size_hint(&self) -> u32 {
self.phrase_scorer.size_hint()
}
fn cost(&self) -> u64 {
self.phrase_scorer.cost()
}
}
impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {

View File

@@ -368,6 +368,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
slop: u32,
offset: usize,
) -> PhraseScorer<TPostings> {
let num_docs = fieldnorm_reader.num_docs();
let max_offset = term_postings_with_offset
.iter()
.map(|&(offset, _)| offset)
@@ -382,7 +383,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
})
.collect::<Vec<_>>();
let mut scorer = PhraseScorer {
intersection_docset: Intersection::new(postings_with_offsets),
intersection_docset: Intersection::new(postings_with_offsets, num_docs),
num_terms: num_docsets,
left_positions: Vec::with_capacity(100),
right_positions: Vec::with_capacity(100),
@@ -535,6 +536,15 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
fn size_hint(&self) -> u32 {
self.intersection_docset.size_hint()
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Evaluating phrase matches is generally more expensive than simple term matches,
// as it requires loading and comparing positions. Use a conservative multiplier
// based on the number of terms.
self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
}
}
impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {

View File

@@ -176,6 +176,14 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
fn size_hint(&self) -> u32 {
self.column.num_docs()
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Advancing the docset is relatively expensive since it scans the column.
// Keep cost relative to a term query driver; use num_docs as baseline.
self.column.num_docs() as u64
}
}
#[cfg(test)]

View File

@@ -63,6 +63,10 @@ where
fn size_hint(&self) -> u32 {
self.req_scorer.size_hint()
}
fn cost(&self) -> u64 {
self.req_scorer.cost()
}
}
impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer

141
src/query/size_hint.rs Normal file
View File

@@ -0,0 +1,141 @@
/// Computes the estimated number of documents in the intersection of multiple docsets
/// given their sizes.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
/// segment.
///
/// # Returns
/// The estimated number of documents in the intersection.
pub fn estimate_intersection<I>(mut docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
if max_docs == 0u32 {
return 0u32;
}
// Terms tend to be not really randomly distributed.
// This factor is used to adjust the estimate.
let mut co_loc_factor: f64 = 1.3;
let mut intersection_estimate = match docset_sizes.next() {
Some(first_size) => first_size as f64,
None => return 0, // No docsets provided, so return 0.
};
let mut smallest_docset_size = intersection_estimate;
// Assuming random distribution of terms, the probability of a document being in the
// intersection
for size in docset_sizes {
// Diminish the co-location factor for each additional set, or we will overestimate.
co_loc_factor = (co_loc_factor - 0.1).max(1.0);
intersection_estimate *= (size as f64 / max_docs as f64) * co_loc_factor;
smallest_docset_size = smallest_docset_size.min(size as f64);
}
intersection_estimate.round().min(smallest_docset_size) as u32
}
/// Computes the estimated number of documents in the union of multiple docsets
/// given their sizes.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
/// segment.
///
/// # Returns
/// The estimated number of documents in the union.
pub fn estimate_union<I>(docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
// Terms tend to be not really randomly distributed.
// This factor is used to adjust the estimate.
// Unlike intersection, the co-location reduces the estimate.
let co_loc_factor = 0.8;
// The approach for union is to compute the probability of a document not being in any of the
// sets
let mut not_in_any_set_prob = 1.0;
// Assuming random distribution of terms, the probability of a document being in the
// union is the complement of the probability of it not being in any of the sets.
for size in docset_sizes {
let prob_in_set = (size as f64 / max_docs as f64) * co_loc_factor;
not_in_any_set_prob *= 1.0 - prob_in_set;
}
let union_estimate = (max_docs as f64 * (1.0 - not_in_any_set_prob)).round();
union_estimate.min(max_docs as f64) as u32
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_estimate_intersection_small1() {
let docset_sizes = &[500, 1000];
let n = 10_000;
let result = estimate_intersection(docset_sizes.iter().copied(), n);
assert_eq!(result, 60);
}
#[test]
fn test_estimate_intersection_small2() {
let docset_sizes = &[500, 1000, 1500];
let n = 10_000;
let result = estimate_intersection(docset_sizes.iter().copied(), n);
assert_eq!(result, 10);
}
#[test]
fn test_estimate_intersection_large_values() {
let docset_sizes = &[100_000, 50_000, 30_000];
let n = 1_000_000;
let result = estimate_intersection(docset_sizes.iter().copied(), n);
assert_eq!(result, 198);
}
#[test]
fn test_estimate_union_small() {
let docset_sizes = &[500, 1000, 1500];
let n = 10000;
let result = estimate_union(docset_sizes.iter().copied(), n);
assert_eq!(result, 2228);
}
#[test]
fn test_estimate_union_large_values() {
let docset_sizes = &[100000, 50000, 30000];
let n = 1000000;
let result = estimate_union(docset_sizes.iter().copied(), n);
assert_eq!(result, 137997);
}
#[test]
fn test_estimate_intersection_large() {
let docset_sizes: Vec<_> = (0..10).map(|_| 4_000_000).collect();
let n = 5_000_000;
let result = estimate_intersection(docset_sizes.iter().copied(), n);
// Check that it doesn't overflow and returns a reasonable result
assert_eq!(result, 708_670);
}
#[test]
fn test_estimate_intersection_overflow_safety() {
let docset_sizes: Vec<_> = (0..100).map(|_| 4_000_000).collect();
let n = 5_000_000;
let result = estimate_intersection(docset_sizes.iter().copied(), n);
// Check that it doesn't overflow and returns a reasonable result
assert_eq!(result, 0);
}
#[test]
fn test_estimate_union_overflow_safety() {
let docset_sizes: Vec<_> = (0..100).map(|_| 1_000_000).collect();
let n = 20_000_000;
let result = estimate_union(docset_sizes.iter().copied(), n);
// Check that it doesn't overflow and returns a reasonable result
assert_eq!(result, 19_662_594);
}
}

View File

@@ -101,7 +101,7 @@ impl TermQuery {
EnableScoring::Enabled {
statistics_provider,
..
} => Bm25Weight::for_terms(statistics_provider, std::slice::from_ref(&self.term))?,
} => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?,
EnableScoring::Disabled { .. } => {
Bm25Weight::new(Explanation::new("<no score>", 1.0f32), 1.0f32)
}

View File

@@ -2,6 +2,7 @@ use common::TinySet;
use crate::docset::{DocSet, TERMINATED};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::size_hint::estimate_union;
use crate::query::Scorer;
use crate::{DocId, Score};
@@ -50,6 +51,8 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
doc: DocId,
/// Combined score for current `doc` as produced by `TScoreCombiner`.
score: Score,
/// Number of documents in the segment.
num_docs: u32,
}
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
@@ -78,9 +81,11 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
/// num_docs is the number of documents in the segment.
pub(crate) fn build(
docsets: Vec<TScorer>,
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
num_docs: u32,
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
let non_empty_docsets: Vec<TScorer> = docsets
.into_iter()
@@ -94,6 +99,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
window_start_doc: 0,
doc: 0,
score: 0.0,
num_docs,
};
if union.refill() {
union.advance();
@@ -218,11 +224,11 @@ where
}
fn size_hint(&self) -> u32 {
self.docsets
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
estimate_union(self.docsets.iter().map(DocSet::size_hint), self.num_docs)
}
fn cost(&self) -> u64 {
self.docsets.iter().map(|docset| docset.cost()).sum()
}
fn count_including_deleted(&mut self) -> u32 {

View File

@@ -27,11 +27,17 @@ mod tests {
docs_list.iter().cloned().map(VecDocSet::from)
}
fn union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
let max_doc = docs_list
.iter()
.flat_map(|docs| docs.iter().copied())
.max()
.unwrap_or(0);
Box::new(BufferedUnionScorer::build(
vec_doc_set_from_docs_list(docs_list)
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<ConstScorer<VecDocSet>>>(),
DoNothingCombiner::default,
max_doc,
))
}
@@ -273,6 +279,7 @@ mod bench {
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<_>>(),
DoNothingCombiner::default,
100_000,
);
while v.doc() != TERMINATED {
v.advance();
@@ -294,6 +301,7 @@ mod bench {
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<_>>(),
DoNothingCombiner::default,
100_000,
);
while v.doc() != TERMINATED {
v.advance();

View File

@@ -99,6 +99,10 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
.unwrap_or(0u32)
}
fn cost(&self) -> u64 {
self.docsets.iter().map(|docset| docset.cost()).sum()
}
fn count_including_deleted(&mut self) -> u32 {
if self.doc == TERMINATED {
return 0u32;

View File

@@ -342,7 +342,7 @@ mod tests {
fn test_pack() -> crate::Result<()> {
let mut store_writer = TermInfoStoreWriter::new();
let mut term_infos = vec![];
let offset = |i| (i * 13 + i * i);
let offset = |i| i * 13 + i * i;
for i in 0usize..1000usize {
let term_info = TermInfo {
doc_freq: i as u32,