Compare commits

...

25 Commits

Author SHA1 Message Date
Pascal Seitz
3ddca31292 simplify fetch block in column_block_accessor 2025-12-29 10:56:50 +01:00
Pascal Seitz
87fe3a311f share column block accessor 2025-12-12 16:54:44 +08:00
Pascal Seitz
71dc08424c add comment 2025-12-12 16:19:06 +08:00
Pascal Seitz
15913446b8 cleanup
remove clone
move data in term req, single doc opt for stats
2025-12-10 14:34:43 +08:00
Pascal Seitz
78bd3826dc remove stacktrace bloat, use &mut helper
increase cache to 2048
2025-12-08 10:35:23 +08:00
Pascal Seitz
1b56487307 specialize columntype in stats 2025-12-08 10:20:09 +08:00
Pascal Seitz
030554d544 use radix map, fix prepare_max_bucket
use paged term map in term agg
use special no sub agg term map impl
2025-12-08 10:20:09 +08:00
Pascal Seitz
c852bac532 reduce dynamic dispatch, faster term agg 2025-12-08 10:20:09 +08:00
Pascal Seitz
2ce4da8b66 one collector per agg request instead per bucket
In this refactoring a collector knows in which bucket of the parent
their data is in. This allows to convert the previous approach of one
collector per bucket to one collector per request.

low card bucket optimization
2025-12-08 10:20:09 +08:00
Pascal Seitz
0dd6a958f8 add more tests for new collection type 2025-12-08 10:20:09 +08:00
Pascal Seitz
254314a4a3 improve bench 2025-12-07 10:05:30 +08:00
PSeitz
b2f99c6217 add term->histogram benchmark (#2758)
* add term->histogram benchmark

* add more term aggs

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-04 02:29:37 +01:00
PSeitz
76de5bab6f fix unsafe warnings (#2757) 2025-12-03 20:15:21 +08:00
rustmailer
b7eb31162b docs: add usage example to README (#2743) 2025-12-02 21:56:57 +01:00
Paul Masurel
63c66005db Lazy scorers (#2726)
* Refactoring of the score tweaker into `SortKeyComputer`s to unlock two features.

- Allow lazy evaluation of score. As soon as we identified that a doc won't
reach the topK threshold, we can stop the evaluation.
- Allow for a different segment level score, segment level score and their conversion.

This PR breaks public API, but fixing code is straightforward.

* Bumping tantivy version

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 15:38:57 +01:00
Paul Masurel
7d513a44c5 Added some benchmark for top K by a fast field (#2754)
Also removed query parsing from the bench code.

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-01 14:58:29 +01:00
Stu Hood
ca87fcd454 Implement collect_block for Collectors which wrap other Collectors (#2727)
* Implement `collect_block` for tuple Collectors, and for MultiCollector.

* Two more.
2025-12-01 12:26:29 +01:00
Ang
08a92675dc Fix typos again (#2753)
Found via `codespell -S benches,stopwords.rs -L
womens,parth,abd,childs,ond,ser,ue,mot,hel,atleast,pris,claus,allo`
2025-12-01 12:15:41 +01:00
Raphaël Cohen
f7f4b354d6 fix: Handle phrase prefixed with star (#2751)
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2025-12-01 11:43:25 +01:00
Paul Masurel
25d44fcec8 Revert "remove unused columnar api (#2742)" (#2748)
* Revert "remove unused columnar api (#2742)"

This reverts commit 8725594d47.

* Clippy comment + removing fill_vals

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 17:44:02 +01:00
PSeitz-dd
842fe9295f split Term in Term and IndexingTerm (#2744)
* split Term in Term and IndexingTerm

* add append_json_path to JsonTermSerializer
2025-11-26 16:48:59 +01:00
Paul Masurel
f88b7200b2 Optimization when posting list are saturated. (#2745)
* Optimization when posting list are saturated.

If a posting list doc freq is the segment reader's
max_doc, and if scoring does not matter, we can replace it
by a AllScorer.

In turn, in a boolean query, we can dismiss  all scorers and
empty scorers, to accelerate the request.

* Added range query optimization

* CR comment

* CR comments

* CR comment

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-11-26 15:50:57 +01:00
PSeitz-dd
8725594d47 remove unused columnar api (#2742) 2025-11-21 18:07:25 +01:00
PSeitz
43a784671a clippy (#2741)
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-11-21 18:07:03 +01:00
Paul Masurel
c363bbd23d Optimize term aggregation with low cardinality + some refactoring (#2740)
This introduce an optimization of top level term aggregation on field with a low cardinality.

We then use a Vec as the underlying map.
In addition, we buffer subaggregations.

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Paul Masurel <paul@quickwit.io>
2025-11-21 14:46:29 +01:00
114 changed files with 6054 additions and 2931 deletions

View File

@@ -78,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
- **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
- **Performace/Memory**
- **Performance/Memory**
- lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
- Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
- Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.25.0"
version = "0.26.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]

View File

@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
remove fast field reader
find a way to unify the two DateTime.
readd type check in the filter wrapper
re-add type check in the filter wrapper
add unit test on columnar list columns.

View File

@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -53,25 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_few);
register!(group, histogram_with_term_agg_status);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
@@ -130,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
}
});
execute_agg(index, agg_req);
}
@@ -150,10 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -166,13 +175,20 @@ fn terms_few_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) {
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -220,6 +236,72 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -275,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_few(index: &Index) {
fn range_agg_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -290,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
}
},
});
@@ -346,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_few(index: &Index) {
fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } }
"my_texts": { "terms": { "field": "text_few_terms_status" } }
}
}
});
@@ -396,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// crreate dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -404,20 +493,47 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -427,15 +543,25 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -460,8 +586,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -513,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}
@@ -529,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}

View File

@@ -16,14 +16,15 @@
// - This bench isolates boolean iteration speed and intersection/union cost.
// - Use `cargo bench --bench boolean_conjunction` to run.
use binggan::{black_box, BenchRunner};
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
#[derive(Clone)]
struct BenchIndex {
@@ -33,23 +34,6 @@ struct BenchIndex {
query_parser: QueryParser,
}
impl BenchIndex {
#[inline(always)]
fn count_query(&self, query_str: &str) -> usize {
let query = self.query_parser.parse_query(query_str).unwrap();
self.searcher.search(&query, &Count).unwrap()
}
#[inline(always)]
fn topk_len(&self, query_str: &str, k: usize) -> usize {
let query = self.query_parser.parse_query(query_str).unwrap();
self.searcher
.search(&query, &TopDocs::with_limit(k))
.unwrap()
.len()
}
}
/// Build a single index containing both fields (title, body) and
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
@@ -59,6 +43,8 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
let f_body = schema_builder.add_text_field("body", TEXT);
let f_score = schema_builder.add_u64_field("score", FAST);
let f_score2 = schema_builder.add_u64_field("score2", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
@@ -67,11 +53,13 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
// Populate: spread each present token 90/10 to body/title
{
let mut writer = index.writer(500_000_000).unwrap();
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.gen_bool(p_a as f64);
let has_b = rng.gen_bool(p_b as f64);
let has_c = rng.gen_bool(p_c as f64);
let score = rng.gen_range(0u64..100u64);
let score2 = rng.gen_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
@@ -101,7 +89,9 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
writer
.add_document(doc!(
f_title=>title_tokens.join(" "),
f_body=>body_tokens.join(" ")
f_body=>body_tokens.join(" "),
f_score=>score,
f_score2=>score2,
))
.unwrap();
}
@@ -153,72 +143,76 @@ fn main() {
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
// Single-field group: default field is body only
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("single_field — {}", label));
group.register_with_input("+a_+b_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b"))
});
group.register_with_input("+a_+b_+c_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b +c"))
});
group.register_with_input("+a_+b_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b", 10))
});
group.register_with_input("+a_+b_+c_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b +c", 10))
});
// OR queries
group.register_with_input("a_OR_b_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b"))
});
group.register_with_input("a_OR_b_OR_c_count", &single_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b OR c"))
});
group.register_with_input("a_OR_b_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b", 10))
});
group.register_with_input("a_OR_b_OR_c_top10", &single_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b OR c", 10))
});
group.run();
}
// Multi-field group: default fields are [title, body]
{
let mut group = runner.new_group();
group.set_name(format!("multi_field — {}", label));
group.register_with_input("+a_+b_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b"))
});
group.register_with_input("+a_+b_+c_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("+a +b +c"))
});
group.register_with_input("+a_+b_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b", 10))
});
group.register_with_input("+a_+b_+c_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("+a +b +c", 10))
});
// OR queries
group.register_with_input("a_OR_b_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b"))
});
group.register_with_input("a_OR_b_OR_c_count", &multi_view, |benv: &BenchIndex| {
black_box(benv.count_query("a OR b OR c"))
});
group.register_with_input("a_OR_b_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b", 10))
});
group.register_with_input("a_OR_b_OR_c_top10", &multi_view, |benv: &BenchIndex| {
black_box(benv.topk_len("a OR b OR c", 10))
});
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
"top10_by_ff",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by((
SortByStaticFastValue::<u64>::for_field("score"),
SortByStaticFastValue::<u64>::for_field("score2"),
)),
"top10_by_2ff",
);
}
group.run();
}
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
}

View File

@@ -258,7 +258,7 @@ mod test {
bitpacker.write(val, num_bits, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
(bitunpacker, vals, data)
}
@@ -304,7 +304,7 @@ mod test {
bitpacker.write(val, num_bits, &mut buffer).unwrap();
}
bitpacker.flush(&mut buffer).unwrap();
assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
let bitunpacker = BitUnpacker::new(num_bits);
let max_val = if num_bits == 64 {
u64::MAX

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
}
output_tail.offset_from(output) as usize
unsafe { output_tail.offset_from(output) as usize }
}
#[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
}
union U8x32 {

View File

@@ -73,7 +73,7 @@ The crate introduces the following concepts.
`Columnar` is an equivalent of a dataframe.
It maps `column_key` to `Column`.
A `Column<T>` asssociates a `RowId` (u32) to any
A `Column<T>` associates a `RowId` (u32) to any
number of values.
This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.

View File

@@ -89,13 +89,6 @@ fn main() {
black_box(sum);
});
group.register("first_block_fetch", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();
column.first_vals(&fetch_docids, &mut block);
black_box(block[0]);
});
group.register("first_block_single_calls", |column| {
let mut block: Vec<Option<u64>> = vec![None; 64];
let fetch_docids = (0..64).collect::<Vec<_>>();

View File

@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -131,6 +131,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
}
/// Get an iterator over the values for the provided docid.
#[inline]
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
self.index
.value_row_ids(doc_id)
@@ -158,15 +160,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
.select_batch_in_place(selected_docid_range.start, doc_ids);
}
/// Fills the output vector with the (possibly multiple values that are associated_with
/// `row_id`.
///
/// This method clears the `output` vector.
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
output.clear();
output.extend(self.values_for_doc(row_id));
}
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
Arc::new(FirstValueWithDefault {
column: self,

View File

@@ -1,7 +1,7 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
/// Montonic maps a value to u128 value space
/// Monotonic maps a value to u128 value space
/// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
/// Converts a value to u128.

View File

@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
const MID_POINT: u64 = (1u64 << 32) - 1u64;
/// `Line` describes a line function `y: ax + b` using integer
/// arithmetics.
/// arithmetic.
///
/// The slope is in fact a decimal split into a 32 bit integer value,
/// and a 32-bit decimal value.
@@ -94,7 +94,7 @@ impl Line {
// `(i, ys[])`.
//
// The best intercept therefore has the form
// `y[i] - line.eval(i)` (using wrapping arithmetics).
// `y[i] - line.eval(i)` (using wrapping arithmetic).
// In other words, the best intercept is one of the `y - Line::eval(ys[i])`
// and our task is just to pick the one that minimizes our error.
//

View File

@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
) -> io::Result<()>;
}
/// A column codec describes a colunm serialization format.
/// A column codec describes a column serialization format.
pub trait ColumnCodec<T: PartialOrd = u64> {
/// Specialized `ColumnValues` type.
type ColumnValues: ColumnValues<T> + 'static;

View File

@@ -181,6 +181,14 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)

View File

@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
writer.write_all(&buffer)
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u128;
let mut shift = 0u64;
@@ -195,7 +197,9 @@ impl BinarySerializable for VInt {
writer.write_all(&buffer[0..num_bytes])
}
#[allow(clippy::unbuffered_bytes)]
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
#[allow(clippy::unbuffered_bytes)]
let mut bytes = reader.bytes();
let mut result = 0u64;
let mut shift = 0u64;

View File

@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
// is the role of the `TopDocs` collector.
// We can now perform our query.
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
// The actual documents still need to be
// retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("title:sea^20 body:whale^70")?;
let (_score, doc_address) = searcher
.search(&query, &TopDocs::with_limit(1))?
.search(&query, &TopDocs::with_limit(1).order_by_score())?
.into_iter()
.next()
.unwrap();

View File

@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
// here we want to get a hit on the 'ken' in Frankenstein
let query = query_parser.parse_query("ken")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
{
// Simple exact search on the date
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Range query on the date field
let query = query_parser
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
assert_eq!(count_docs.len(), 1);
for (_score, doc_address) in count_docs {
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;

View File

@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
// The second argument is here to tell we don't care about decoding positions,
// or term frequencies.
let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;
if let Some((_score, doc_address)) = top_docs.first() {
let doc = searcher.doc(*doc_address)?;

View File

@@ -145,7 +145,7 @@ fn main() -> tantivy::Result<()> {
let query = FuzzyTermQuery::new(term, 2, true);
let (top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(5), Count))
.search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
.unwrap();
assert_eq!(count, 3);
assert_eq!(top_docs.len(), 3);

View File

@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
{
// Inclusive range queries
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
assert_eq!(count_docs.len(), 1);
}
{
// Exclusive range queries
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 0);
}
{
// Find docs with IP addresses smaller equal 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
// Find docs with IP addresses smaller than 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}

View File

@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
{
let query = query_parser.parse_query("target:submit-button")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
let query = query_parser.parse_query("target:submit")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(count_docs.len(), 2);
}
{
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
}
{
let query = query_parser.parse_query("click AND cart.product_id:133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
{
// The sub-fields in the json field marked as default field still need to be explicitly
// addressed
let query = query_parser.parse_query("click AND 133")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
{
// Default json fields are ignored if they collide with the schema
let query = query_parser.parse_query("event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 0);
}
// # Query via full attribute path
{
// This only searches in our schema's `event_type` field
let query = query_parser.parse_query("event_type:click")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 2);
}
{
// Default json fields can still be accessed by full path
let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(hits.len(), 1);
}
Ok(())

View File

@@ -63,7 +63,7 @@ fn main() -> Result<()> {
// but not "in the Gulf Stream".
let query = query_parser.parse_query("\"in the su\"*")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let mut titles = top_docs
.into_iter()
.map(|(_score, doc_address)| {

View File

@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 2);
@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
IndexRecordOption::Basic,
);
let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
let (_top_docs, count) =
searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;
assert_eq!(count, 0);

View File

@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
let query_parser = QueryParser::for_index(&index, vec![title, body]);
let query = query_parser.parse_query("sycamore spring")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;

View File

@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
// stop words are applied on the query as well.
// The following will be equivalent to `title:frankenstein`
let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (score, doc_address) in top_docs {
let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;

View File

@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
move |doc_id: DocId| Reverse(price[doc_id as usize])
};
let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);
let hits = searcher.search(&query, &most_expensive_first)?;
assert_eq!(

View File

@@ -758,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
alt((
delimited(char('('), ast, char(')')),
map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
map(
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
),
map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
literal,
))(inp)
@@ -779,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
),
),
(
value((), char('*')),
value(
(),
terminated(
char('*'),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
),
map(nothing, |_| {
(Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
}),
@@ -1671,6 +1691,21 @@ mod test {
test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
// Phrase prefixed with *
test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
test_parse_query_to_ast_helper("*A", "*A");
test_parse_query_to_ast_helper("(*A)", "*A");
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
}
#[test]
fn test_parse_query_all() {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
}
#[test]

View File

@@ -16,15 +16,16 @@ use crate::index::SegmentReader;
/// That way we can use it the same way as if it would come from the fastfield.
pub(crate) fn get_missing_val_as_u64_lenient(
column_type: ColumnType,
column_max_value: u64,
missing: &Key,
field_name: &str,
) -> crate::Result<Option<u64>> {
let missing_val = match missing {
Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
// Allow fallback to number on text fields
Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
Key::F64(val) if column_type.numerical_type().is_some() => {
f64_to_fastfield_u64(*val, &column_type)
}

View File

@@ -1,4 +1,4 @@
use columnar::{Column, ColumnType, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl AggregationsSegmentCtx {
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {
/// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> =
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
}
pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new();
for node in nodes.iter() {
@@ -373,9 +357,7 @@ pub(crate) fn build_segment_agg_collector(
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match node.kind {
AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node),
AggKind::MissingTerm => {
let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
if req_data.accessors.is_empty() {
@@ -390,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
@@ -400,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count
| StatsType::Max
| StatsType::Min
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
node.idx_in_req_data,
))),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
| StatsType::Stats => build_segment_stats_collector(req_data),
StatsType::ExtendedStats(sigma) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
}
}
}
AggKind::TopHits => {
@@ -430,9 +415,7 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
@@ -495,10 +478,11 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
context,
column_block_accessor: ColumnBlockAccessor::default(),
};
for (name, agg) in aggs.iter() {
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?;
let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?;
data.per_request.agg_tree.extend(nodes);
}
Ok(data)
@@ -510,6 +494,7 @@ fn build_nodes(
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
use AggregationVariants::*;
match &req.agg {
@@ -522,9 +507,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: range_req.clone(),
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -542,9 +527,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(),
is_date_histogram: false,
bounds: HistogramBounds {
@@ -569,9 +552,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req,
is_date_histogram: true,
bounds: HistogramBounds {
@@ -596,6 +577,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Terms(terms_req.clone()),
is_top_level,
),
Cardinality(card_req) => build_terms_or_cardinality_nodes(
agg_name,
@@ -606,6 +588,7 @@ fn build_nodes(
data,
&req.sub_aggregation,
TermsOrCardinalityRequest::Cardinality(card_req.clone()),
is_top_level,
),
Average(AverageAggregation { field, missing, .. })
| Max(MaxAggregation { field, missing, .. })
@@ -649,7 +632,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for,
missing: *missing,
@@ -677,7 +659,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing,
@@ -734,7 +715,7 @@ fn build_nodes(
// Build the query and evaluator upfront
let schema = reader.schema();
let tokenizers = &data.context.tokenizers;
let query = filter_req.parse_query(&schema, tokenizers)?;
let query = filter_req.parse_query(schema, tokenizers)?;
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
query,
schema.clone(),
@@ -771,7 +752,14 @@ fn build_children(
) -> crate::Result<Vec<AggRefNode>> {
let mut children = Vec::new();
for (name, agg) in aggs.iter() {
children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?);
children.extend(build_nodes(
name,
agg,
reader,
segment_ordinal,
data,
false,
)?);
}
Ok(children)
}
@@ -835,6 +823,7 @@ fn build_terms_or_cardinality_nodes(
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: TermsOrCardinalityRequest,
is_top_level: bool,
) -> crate::Result<Vec<AggRefNode>> {
let mut nodes = Vec::new();
@@ -886,12 +875,12 @@ fn build_terms_or_cardinality_nodes(
});
}
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
// Add one node per accessor
for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg {
None
} else if let Some(m) = missing.as_ref() {
get_missing_val_as_u64_lenient(column_type, m, field_name)?
get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)?
} else {
None
};
@@ -917,13 +906,11 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
});
(idx_in_req_data, AggKind::Terms)
}
@@ -933,7 +920,6 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: req.clone(),
});

View File

@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
/// Allocated memory with this guard.
allocated_with_the_guard: u64,
}
impl Clone for AggregationLimitsGuard {
fn clone(&self) -> Self {
Self {
@@ -70,7 +71,7 @@ impl AggregationLimitsGuard {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*

View File

@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggegation result.
/// The final aggregation result.
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {

View File

@@ -2,15 +2,441 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushng part of these tests are outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However they are useful as they test a complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

View File

@@ -6,10 +6,12 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -32,7 +34,7 @@ use crate::{DocId, SegmentReader, TantivyError};
///
/// # Implementation Requirements
///
/// Implementors must:
/// Implementers must:
/// 1. Derive `Debug`, `Clone`, `Serialize`, and `Deserialize`
/// 2. Use `#[typetag::serde]` attribute on the impl block
/// 3. Implement `build_query()` to construct the query from schema/tokenizers
@@ -410,9 +412,9 @@ impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
}
}
@@ -489,12 +491,19 @@ impl Debug for DocumentQueryEvaluator {
}
}
#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
doc_count: u64,
bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector {
/// Document count in this bucket
doc_count: u64,
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregations: Option<CachedSubAggs<true>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
@@ -511,11 +520,13 @@ impl SegmentFilterCollector {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
doc_count: 0,
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
@@ -523,35 +534,41 @@ impl SegmentFilterCollector {
impl Debug for SegmentFilterCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("doc_count", &self.doc_count)
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
}
// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: self.doc_count,
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
sub_aggregations: sub_results,
};
@@ -570,32 +587,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
Ok(())
}
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
fn collect(
&mut self,
docs: &[DocId],
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -604,18 +606,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
self.doc_count += req.matching_docs_buffer.len() as u64;
bucket.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(())
}
@@ -626,6 +634,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}
/// Intermediate result for filter aggregation
@@ -639,16 +662,14 @@ pub struct IntermediateFilterBucketResult {
#[cfg(test)]
mod tests {
use std::time::Instant;
use serde_json::{json, Value};
use super::*;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::{AggContextParams, AggregationCollector};
use crate::query::{AllQuery, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT};
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT};
use crate::{doc, Index, IndexWriter};
// Test helper functions
@@ -729,12 +750,13 @@ mod tests {
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer(50_000_000)?;
let mut writer: IndexWriter = index.writer_for_tests()?;
writer.add_document(doc!(
category => "electronics", brand => "apple",
price => 999u64, rating => 4.5f64, in_stock => true
))?;
writer.commit()?;
writer.add_document(doc!(
category => "electronics", brand => "samsung",
price => 799u64, rating => 4.2f64, in_stock => true
@@ -938,7 +960,7 @@ mod tests {
let index = create_standard_test_index()?;
let reader = index.reader()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 2);
let agg = json!({
"premium_electronics": {
"filter": "category:electronics AND price:[800 TO *]",
@@ -1520,9 +1542,9 @@ mod tests {
let searcher = reader.searcher();
let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -257,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregation: &mut Option<CachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
}
}
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<CachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, &req.field_type);
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
}
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let blueprint = if !node.children.is_empty() {
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
req_data.sub_aggregation_blueprint = blueprint;
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
buckets: Default::default(),
sub_aggregations: Default::default(),
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

View File

@@ -1,18 +1,20 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -23,12 +25,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
@@ -151,19 +153,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
pub struct SegmentRangeCollector<const LOWCARD: bool = false> {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<LOWCARD>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
impl<const LOWCARD: bool> Debug for SegmentRangeCollector<LOWCARD> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -184,48 +214,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
let sub_aggregation = IntermediateAggregationResults::default();
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res,
sub_aggregation_res: sub_aggregation,
from: self.from,
to: self.to,
})
}
}
impl SegmentAggregationCollector for SegmentRangeCollector {
impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<LOWCARD> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter()
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
})
.collect::<crate::Result<_>>()?;
@@ -242,73 +274,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
let req = agg_data.take_range_req_data(self.accessor_idx);
req.column_block_accessor.fetch_block(docs, &req.accessor);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in req
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
// we can are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<true>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<false>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -317,20 +390,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
let sub_aggregation = sub_agg_prototype.clone();
// let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation,
bucket_id,
key,
from,
to,
@@ -339,27 +412,20 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
req_data.context.limits.add_memory_consumed(
self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
Ok(buckets)
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
/// Converts the user provided f64 range value to fast field value space.
///
@@ -456,7 +522,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, field_type).to_string())
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
}
};
@@ -506,30 +572,33 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
buckets,
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
@@ -776,7 +845,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -799,7 +868,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -814,7 +883,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -823,7 +892,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -832,7 +901,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -878,7 +947,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -890,63 +959,3 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
}
}
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
sub_agg: Option<CachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
}
impl TermMissingAgg {
pub(crate) fn new(
req_data: &mut AggregationsSegmentCtx,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
..Default::default()
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: self.missing_count,
doc_count: missing_count.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = self.sub_agg {
if let Some(sub_agg) = &mut self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?;
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_data)?;
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}

View File

@@ -1,83 +0,0 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_data)?;
Ok(())
}
}

View File

@@ -0,0 +1,185 @@
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
#[derive(Debug)]
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
pub(crate) struct CachedSubAggs<const LOWCARD: bool = false> {
/// Only used when LOWCARD is true.
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
/// Only used when LOWCARD is false.
///
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and they may be nothing to group.
partitions: [PartitionEntry; NUM_PARTITIONS],
pub(crate) sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
const FLUSH_THRESHOLD: usize = 2048;
const NUM_PARTITIONS: usize = 16;
impl<const LOWCARD: bool> CachedSubAggs<LOWCARD> {
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
per_bucket_docs: Vec::new(),
num_docs: 0,
sub_agg_collector: sub_agg,
partitions: core::array::from_fn(|_| PartitionEntry::default()),
}
}
#[inline]
pub fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
for partition in &mut self.partitions {
partition.clear();
}
self.num_docs = 0;
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
if LOWCARD {
// TODO: We could flush single buckets here
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
} else {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.flush_local(agg_data, false)?;
}
Ok(())
}
/// Note: this does _not_ flush the sub aggregations
fn flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
if LOWCARD {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
self.sub_agg_collector
.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC == LOWCARD threshold
let bucket_treshold = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
self.sub_agg_collector
.collect(bucket_id as BucketId, docs, agg_data)?;
}
} else {
let mut max_bucket = 0u32;
for partition in &self.partitions {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
self.sub_agg_collector
.prepare_max_bucket(max_bucket, agg_data)?;
for slot in &self.partitions {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
self.sub_agg_collector.collect_multiple(
&slot.bucket_ids,
&slot.docs,
agg_data,
)?;
}
}
}
self.clear();
Ok(())
}
/// Note: this _does_ flush the sub aggregations
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.flush_local(agg_data, true)?;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}

View File

@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::cached_sub_aggs::CachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: BufAggregationCollector,
agg_collector: CachedSubAggs<true>,
error: Option<TantivyError>,
}
@@ -151,8 +151,10 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
let mut result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
@@ -170,26 +172,31 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
if let Err(err) = self
self.agg_collector.push(0, doc);
match self
.agg_collector
.collect(doc, &mut self.aggs_with_accessor)
.check_flush_local(&mut self.aggs_with_accessor)
{
self.error = Some(err);
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
@@ -200,10 +207,13 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
pub sub_aggregation_res: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.sub_aggregation_res
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
Ok(())
}
}
@@ -887,7 +888,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation: Default::default(),
sub_aggregation_res: Default::default(),
from: None,
to: None,
},
@@ -920,7 +921,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
accessor_idx,
entries: FxHashSet::default(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_data: &AggregationsSegmentCtx,
req_data: &CardinalityAggReqData,
) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
bucket.entries.insert(term_ord);
}
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -62,7 +61,7 @@ impl ExtendedStatsAggregation {
/// Extended stats contains a collection of statistics
/// they extends stats adding variance, standard deviation
/// and bound informations
/// and bound information
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
/// The number of documents.
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
Self {
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
}
}
}
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
self.extended_stats,
extended_stats,
)),
)?;
@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
Ok(())
}
}

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result

View File

@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// # Percentiles
///
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) percentiles: PercentilesCollector,
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -229,33 +234,18 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
percentiles: PercentilesCollector::new(),
pub fn from_req_and_validate(
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
results.push(
name,
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
}
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
Ok(())
}
}

View File

@@ -1,5 +1,6 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl SegmentStatsCollector {
pub fn from_req(accessor_idx: usize) -> Self {
Self {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
impl SegmentAggregationCollector for SegmentStatsCollector {
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
let name = self.name.clone();
let intermediate_metric_result = match req.collecting_for {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
req.collecting_for
self.collecting_for
)))
}
};
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -15,11 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -458,7 +458,7 @@ impl Eq for DocSortValuesAndFields {}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
req: TopHitsAggregationReq,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
}
impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -471,7 +471,10 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
req: req.clone(),
}
}
@@ -482,7 +485,7 @@ impl TopHitsTopNComputer {
pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
for doc in other_fruit.top_n.into_vec() {
self.collect(doc.feature, doc.doc);
self.collect(doc.sort_key, doc.doc);
}
Ok(())
}
@@ -494,9 +497,9 @@ impl TopHitsTopNComputer {
.into_sorted_vec()
.into_iter()
.map(|doc| TopHitsVecEntry {
sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
doc_value_fields: doc
.feature
.sort_key
.doc_value_fields
.into_iter()
.map(|(k, v)| (k, v.into()))
@@ -517,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
}
impl TopHitsSegmentCollector {
@@ -526,25 +530,35 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
num_hits,
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn into_top_hits_collector(
self,
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec();
let top_results = top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
top_hits_computer.collect(
DocSortValuesAndFields {
sorts: res.feature,
sorts: res.sort_key,
doc_value_fields,
},
res.doc,
@@ -553,54 +567,24 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, &req_data.req),
);
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
@@ -610,26 +594,57 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
let req = &req_data.req;
let accessors = &req_data.accessors;
for doc_id in docs {
let doc_id = *doc_id;
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
Ok(())
}
}
#[cfg(test)]
@@ -645,6 +660,7 @@ mod tests {
use crate::aggregation::bucket::tests::get_test_index_from_docs;
use crate::aggregation::tests::get_test_index_from_values;
use crate::aggregation::AggregationCollector;
use crate::collector::sort_key::ReverseComparator;
use crate::collector::ComparableDoc;
use crate::query::AllQuery;
use crate::schema::OwnedValue;
@@ -660,7 +676,7 @@ mod tests {
fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
super::TopHitsTopNComputer {
top_n: super::TopNComputer::new(capacity),
top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
req: Default::default(),
}
}
@@ -744,7 +760,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -774,12 +790,12 @@ mod tests {
#[test]
fn test_top_hits_collector_single_feature() -> crate::Result<()> {
let docs = vec![
ComparableDoc::<_, _, false> {
ComparableDoc::<_, _> {
doc: crate::DocAddress {
segment_ord: 0,
doc_id: 0,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(1),
order: Order::Asc,
@@ -792,7 +808,7 @@ mod tests {
segment_ord: 0,
doc_id: 2,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(3),
order: Order::Asc,
@@ -805,7 +821,7 @@ mod tests {
segment_ord: 0,
doc_id: 1,
},
feature: DocSortValuesAndFields {
sort_key: DocSortValuesAndFields {
sorts: vec![DocValueAndOrder {
value: Some(5),
order: Order::Asc,
@@ -817,7 +833,7 @@ mod tests {
let mut collector = collector_with_capacity(3);
for doc in docs.clone() {
collector.collect(doc.feature, doc.doc);
collector.collect(doc.sort_key, doc.doc);
}
let res = collector.into_final_result();
@@ -827,15 +843,15 @@ mod tests {
super::TopHitsMetricResult {
hits: vec![
super::TopHitsVecEntry {
sort: vec![docs[0].feature.sorts[0].value],
sort: vec![docs[0].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[1].feature.sorts[0].value],
sort: vec![docs[1].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
super::TopHitsVecEntry {
sort: vec![docs[2].feature.sorts[0].value],
sort: vec![docs[2].sort_key.sorts[0].value],
doc_value_fields: Default::default(),
},
]
@@ -873,7 +889,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -335,19 +348,37 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
match field_type {
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
}
}

View File

@@ -8,25 +8,67 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug {
pub trait SegmentAggregationCollector: Debug {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
fn collect_block(
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
}
}
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
#[derive(Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?;
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
}
Ok(())
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect_block(docs, agg_data)?;
collector.collect(parent_bucket_id, docs, agg_data)?;
}
Ok(())
}
@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}

View File

@@ -1,121 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score, SegmentReader};
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
}
impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where TScore: Clone + PartialOrd
{
pub(crate) fn new(
custom_scorer: TCustomScorer,
collector: TopCollector<TScore>,
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector {
custom_scorer,
collector,
}
}
}
/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`CustomScorer`].
pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`.
fn score(&mut self, doc: DocId) -> TScore;
}
/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomerScorer` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait CustomScorer<TScore>: Sync {
/// Type of the associated [`CustomSegmentScorer`].
type Child: CustomSegmentScorer<TScore>;
/// Builds a child scorer for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
}
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
TCustomScorer: CustomScorer<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
Ok(CustomScoreTopSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
false
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
T: CustomSegmentScorer<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: T,
}
impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
T: 'static + CustomSegmentScorer<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, _score: Score) {
let score = self.segment_scorer.score(doc);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, T> CustomScorer<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
T: CustomSegmentScorer<TScore>,
{
type Child = T;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> CustomSegmentScorer<TScore> for F
where F: 'static + FnMut(DocId) -> TScore
{
fn score(&mut self, doc: DocId) -> TScore {
(self)(doc)
}
}

View File

@@ -12,6 +12,7 @@ use std::marker::PhantomData;
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentReader};
/// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
///
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
///
/// assert_eq!(filtered_top_docs.len(), 0);
@@ -104,6 +105,11 @@ where
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -120,6 +126,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
t_predicate_value: PhantomData,
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
segment_collector: TSegmentCollector,
predicate: TPredicate,
t_predicate_value: PhantomData<TPredicateValue>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate, TPredicateValue>
@@ -176,6 +184,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}
@@ -218,7 +240,7 @@ where
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
@@ -258,6 +280,10 @@ where
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.collector.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -274,6 +300,7 @@ where
segment_collector,
predicate: self.predicate.clone(),
buffer: Vec::new(),
filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
})
}
@@ -296,6 +323,7 @@ where TPredicate: 'static
segment_collector: TSegmentCollector,
predicate: TPredicate,
buffer: Vec<u8>,
filtered_docs: Vec<DocId>,
}
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
@@ -334,6 +362,20 @@ where
}
}
fn collect_block(&mut self, docs: &[DocId]) {
self.filtered_docs.clear();
for &doc in docs {
// TODO: `accept_document` could be further optimized to do batch lookups of column
// values for single-valued columns.
if self.accept_document(doc) {
self.filtered_docs.push(doc);
}
}
if !self.filtered_docs.is_empty() {
self.segment_collector.collect_block(&self.filtered_docs);
}
}
fn harvest(self) -> TSegmentCollector::Fruit {
self.segment_collector.harvest()
}

View File

@@ -57,7 +57,7 @@
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
//! # let query = query_parser.parse_query("diary")?;
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
//! searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
//! searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
//! # Ok(())
//! # }
//! ```
@@ -83,11 +83,15 @@
use downcast_rs::impl_downcast;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
mod count_collector;
pub use self::count_collector::Count;
/// Sort keys
pub mod sort_key;
mod histogram_collector;
pub use histogram_collector::HistogramCollector;
@@ -95,16 +99,13 @@ mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
pub use self::top_collector::ComparableDoc;
mod top_score_collector;
pub use self::top_collector::ComparableDoc;
pub use self::top_score_collector::{TopDocs, TopNComputer};
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
mod tweak_score_top_collector;
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod sort_key_top_collector;
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
mod facet_collector;
pub use self::facet_collector::{FacetCollector, FacetCounts};
use crate::query::Weight;
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
/// Type of the `SegmentCollector` associated with this collector.
type Child: SegmentCollector;
/// Returns an error if the schema is not compatible with the collector.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// `set_segment` is called before beginning to enumerate
/// on this segment.
fn for_segment(
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
match (reader.alive_bitset(), self.requires_scoring()) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
Ok(segment_collector.harvest())
}
}
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
match (reader.alive_bitset(), with_scoring) {
(Some(alive_bitset), true) => {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
(None, true) => {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
Ok(())
}
impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
type Fruit = Option<TSegmentCollector::Fruit>;
@@ -214,6 +229,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
}
}
fn collect_block(&mut self, docs: &[DocId]) {
if let Some(segment_collector) = self {
segment_collector.collect_block(docs);
}
}
fn harvest(self) -> Self::Fruit {
self.map(|segment_collector| segment_collector.harvest())
}
@@ -224,6 +245,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
type Child = Option<<TCollector as Collector>::Child>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
if let Some(underlying_collector) = self {
underlying_collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -299,6 +327,12 @@ where
type Fruit = (Left::Fruit, Right::Fruit);
type Child = (Left::Child, Right::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -342,6 +376,11 @@ where
self.1.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest())
}
@@ -358,6 +397,13 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
type Child = (One::Child, Two::Child, Three::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -407,6 +453,12 @@ where
self.2.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(self.0.harvest(), self.1.harvest(), self.2.harvest())
}
@@ -424,6 +476,14 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
type Child = (One::Child, Two::Child, Three::Child, Four::Child);
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -482,6 +542,13 @@ where
self.3.collect(doc, score);
}
fn collect_block(&mut self, docs: &[DocId]) {
self.0.collect_block(docs);
self.1.collect_block(docs);
self.2.collect_block(docs);
self.3.collect_block(docs);
}
fn harvest(self) -> <Self as SegmentCollector>::Fruit {
(
self.0.harvest(),

View File

@@ -3,6 +3,7 @@ use std::ops::Deref;
use super::{Collector, SegmentCollector};
use crate::collector::Fruit;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};
/// MultiFruit keeps Fruits from every nested Collector
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
type Fruit = Box<dyn Fruit>;
type Child = Box<dyn BoxableSegmentCollector>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn for_segment(
&self,
segment_local_id: u32,
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
/// let searcher = reader.searcher();
///
/// let mut collectors = MultiCollector::new();
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
/// let count_handle = collectors.add_collector(Count);
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
type Fruit = MultiFruit;
type Child = MultiCollectorChild;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
for collector in &self.collector_wrappers {
collector.check_schema(schema)?;
}
Ok(())
}
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
}
}
fn collect_block(&mut self, docs: &[DocId]) {
for child in &mut self.children {
child.collect_block(docs);
}
}
fn harvest(self) -> MultiFruit {
MultiFruit {
sub_fruits: self
@@ -293,7 +311,7 @@ mod tests {
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut collectors = MultiCollector::new();
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
let count_handler = collectors.add_collector(Count);
let mut multifruits = searcher.search(&query, &collectors).unwrap();

View File

@@ -0,0 +1,393 @@
mod order;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::ops::Range;
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let id = schema_builder.add_u64_field("id", FAST);
let city = schema_builder.add_text_field("city", TEXT | FAST);
let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
let altitude = schema_builder.add_f64_field("altitude", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
for doc in docs {
index_writer.add_document(doc)?;
}
index_writer.commit()?;
Ok(())
}
create_segment(
&index,
vec![
doc!(
id => 0_u64,
city => "austin",
catchphrase => "Hills, Barbeque, Glow",
altitude => 149.0,
),
doc!(
id => 1_u64,
city => "greenville",
catchphrase => "Grow, Glow, Glow",
altitude => 27.0,
),
],
)?;
create_segment(
&index,
vec![doc!(
id => 2_u64,
city => "tokyo",
catchphrase => "Glow, Glow, Glow",
altitude => 40.0,
)],
)?;
create_segment(
&index,
vec![doc!(
id => 3_u64,
catchphrase => "No, No, No",
altitude => 0.0,
)],
)?;
Ok(index)
}
// NOTE: You cannot determine the SegmentIds that will be generated for Segments
// ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
searcher
.search(&AllQuery, &DocSetCollector)
.unwrap()
.into_iter()
.map(|doc_address| {
let column = searcher.segment_readers()[doc_address.segment_ord as usize]
.fast_fields()
.u64("id")
.unwrap();
(doc_address, column.first(doc_address.doc_id).unwrap())
})
.collect()
}
#[test]
fn test_order_by_string() -> crate::Result<()> {
let index = make_index()?;
#[track_caller]
fn assert_query(
index: &Index,
order: Order,
doc_range: Range<usize>,
expected: Vec<(Option<String>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::for_doc_range(doc_range)
.order_by((SortByString::for_field("city"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
0..4,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
(None, 3),
],
)?;
assert_query(
&index,
Order::Asc,
0..3,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Asc,
0..2,
vec![
(Some("austin".to_owned()), 0),
(Some("greenville".to_owned()), 1),
],
)?;
assert_query(
&index,
Order::Asc,
0..1,
vec![(Some("austin".to_string()), 0)],
)?;
assert_query(
&index,
Order::Asc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("tokyo".to_owned()), 2),
],
)?;
assert_query(
&index,
Order::Desc,
0..4,
vec![
(Some("tokyo".to_owned()), 2),
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
(None, 3),
],
)?;
assert_query(
&index,
Order::Desc,
1..3,
vec![
(Some("greenville".to_owned()), 1),
(Some("austin".to_owned()), 0),
],
)?;
assert_query(
&index,
Order::Desc,
0..1,
vec![(Some("tokyo".to_owned()), 2)],
)?;
Ok(())
}
#[test]
fn test_order_by_f64() -> crate::Result<()> {
let index = make_index()?;
fn assert_query(
index: &Index,
order: Order,
expected: Vec<(Option<f64>, u64)>,
) -> crate::Result<()> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
// Try as primitive.
let top_collector = TopDocs::with_limit(3)
.order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
let actual = searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
.collect::<Vec<_>>();
assert_eq!(actual, expected);
Ok(())
}
assert_query(
&index,
Order::Asc,
vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
)?;
assert_query(
&index,
Order::Desc,
vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
)?;
Ok(())
}
#[test]
fn test_order_by_score() -> crate::Result<()> {
let index = make_index()?;
fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
let field = index.schema().get_field("catchphrase").unwrap();
let query_parser = QueryParser::for_index(index, vec![field]);
let text_query = query_parser.parse_query("glow")?;
Ok(searcher
.search(&text_query, &top_collector)?
.into_iter()
.map(|(score, doc)| (score, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Desc)?,
&[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
);
assert_eq!(
&query(&index, Order::Asc)?,
&[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
);
Ok(())
}
#[test]
fn test_order_by_score_then_string() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, Option<String>);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by((
(SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order),
));
Ok(searcher
.search(&AllQuery, &top_collector)?
.into_iter()
.map(|(f, doc)| (f, ids[&doc]))
.collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, Some("austin".to_owned())), 0),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("tokyo".to_owned())), 2),
((1.0, None), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, Some("tokyo".to_owned())), 2),
((1.0, Some("greenville".to_owned())), 1),
((1.0, Some("austin".to_owned())), 0),
((1.0, None), 3),
]
);
Ok(())
}
use proptest::prelude::*;
proptest! {
#[test]
fn test_order_by_string_prop(
order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
limit in 1..64_usize,
offset in 0..64_usize,
segments_terms in
proptest::collection::vec(
proptest::collection::vec(0..32_u8, 1..32_usize),
0..8_usize,
)
) {
let mut schema_builder = Schema::builder();
let city = schema_builder.add_text_field("city", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
// represents terms.
for segment_terms in segments_terms.into_iter() {
for term in segment_terms.into_iter() {
let term = format!("{term:0>3}");
index_writer.add_document(doc!(
city => term,
))?;
}
index_writer.commit()?;
}
let searcher = index.reader()?.searcher();
let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
.and_offset(offset)
.order_by_string_fast_field("city", order))?;
let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
// Get the term for this address.
let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
let mut city = Vec::new();
column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
String::try_from(city).unwrap()
});
(value, doc_address)
});
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
prop_assert_eq!(
expected_docs,
top_n_results
);
}
}
}

View File

@@ -0,0 +1,348 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Order, Score};
/// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap()
}
}
/// Sorts document in reverse order.
///
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
NaturalComparator.compare(rhs, lhs)
}
}
/// Sorts document in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
where ReverseComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
}
}
}
impl Comparator<u32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<u64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<f32> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<i64> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
impl Comparator<String> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
ReverseComparator.compare(lhs, rhs)
}
}
/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
/// Natural order (See [NaturalComparator])
#[default]
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
ReverseNoneLower,
}
impl From<Order> for ComparatorEnum {
fn from(order: Order) -> Self {
match order {
Order::Asc => ComparatorEnum::ReverseNoneLower,
Order::Desc => ComparatorEnum::Natural,
}
}
}
impl<T> Comparator<T> for ComparatorEnum
where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
match self {
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
}
}
}
impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
for (LeftComparator, RightComparator)
where
LeftComparator: Comparator<Head>,
RightComparator: Comparator<Tail>,
{
#[inline(always)]
fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
}
}
impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
for (Comparator1, Comparator2, Comparator3)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
{
#[inline(always)]
fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, (Type2, (Type3, Type4)))>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, (Type2, (Type3, Type4))),
rhs: &(Type1, (Type2, (Type3, Type4))),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
.then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
.then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
}
}
impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
Comparator<(Type1, Type2, Type3, Type4)>
for (Comparator1, Comparator2, Comparator3, Comparator4)
where
Comparator1: Comparator<Type1>,
Comparator2: Comparator<Type2>,
Comparator3: Comparator<Type3>,
Comparator4: Comparator<Type4>,
{
#[inline(always)]
fn compare(
&self,
lhs: &(Type1, Type2, Type3, Type4),
rhs: &(Type1, Type2, Type3, Type4),
) -> Ordering {
self.0
.compare(&lhs.0, &rhs.0)
.then_with(|| self.1.compare(&lhs.1, &rhs.1))
.then_with(|| self.2.compare(&lhs.2, &rhs.2))
.then_with(|| self.3.compare(&lhs.3, &rhs.3))
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
where
TSortKeyComputer: SortKeyComputer,
ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
ComparatorEnum: Comparator<
<<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
>,
{
type SortKey = TSortKeyComputer::SortKey;
type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;
type Comparator = ComparatorEnum;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
}
fn comparator(&self) -> Self::Comparator {
self.1.into()
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
segment_sort_key_computer: child,
comparator: self.comparator(),
})
}
}
/// A segment sort key computer with a custom ordering.
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
segment_sort_key_computer: TSegmentSortKeyComputer,
comparator: TComparator,
}
impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score)
}
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.comparator.compare(left, right)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
self.segment_sort_key_computer
.convert_segment_sort_key(sort_key)
}
}

View File

@@ -0,0 +1,77 @@
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
pub struct SortBySimilarityScore;
impl SortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type Child = SortBySimilarityScore;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
true
}
fn segment_sort_key_computer(
&self,
_segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
}
top_n.push(score, doc);
threshold = top_n.threshold.unwrap_or(Score::MIN);
threshold
})?;
} else {
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
top_n.push(score, doc);
top_n.threshold.unwrap_or(Score::MIN)
})?;
}
Ok(top_n
.into_vec()
.into_iter()
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.collect())
}
}
impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score;
type SegmentSortKey = Score;
#[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
score
}
fn convert_segment_sort_key(&self, score: Score) -> Score {
score
}
}

View File

@@ -0,0 +1,98 @@
use std::marker::PhantomData;
use columnar::Column;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};
/// Sorts by a fast value (u64, i64, f64, bool).
///
/// The field must appear explicitly in the schema, with the right type, and declared as
/// a fast field..
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByStaticFastValue<T: FastValue> {
field: String,
typ: PhantomData<T>,
}
impl<T: FastValue> SortByStaticFastValue<T> {
/// Creates a new `SortByStaticFastValue` instance for the given field.
pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
Self {
field: column_name.to_string(),
typ: PhantomData,
}
}
}
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>;
type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
// At the segment sort key computer level, we rely on the u64 representation.
// The mapping is monotonic, so it is sufficient to compute our top-K docs.
let field = schema.get_field(&self.field)?;
let field_entry = schema.get_field_entry(field);
if !field_entry.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is not a fast field.",
self.field,
)));
}
let schema_type = field_entry.field_type().value_type();
if schema_type != T::to_type() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
&self.field,
T::to_type()
)));
}
Ok(())
}
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> {
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let (sort_column, _sort_column_type) =
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
field_name: self.field.clone(),
})?;
Ok(SortByFastValueSegmentSortKeyComputer {
sort_column,
typ: PhantomData,
})
}
}
pub struct SortByFastValueSegmentSortKeyComputer<T> {
sort_column: Column<u64>,
typ: PhantomData<T>,
}
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>;
type SegmentSortKey = Option<u64>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
self.sort_column.first(doc)
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key.map(T::from_u64)
}
}

View File

@@ -0,0 +1,72 @@
use columnar::StrColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a string column.
///
/// The string can be dynamic (coming from a json field)
/// or static (being specificaly defined in the configuration).
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByString {
column_name: String,
}
impl SortByString {
/// Creates a new sort by string sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByString {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByString {
type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
}
}
pub struct ByStringColumnSegmentSortKeyComputer {
str_column_opt: Option<StrColumn>,
}
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let str_column = self.str_column_opt.as_ref()?;
str_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new();
str_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
String::try_from(bytes).ok()
}
}

View File

@@ -0,0 +1,631 @@
use std::cmp::Ordering;
use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// A `SegmentSortKeyComputer` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted.
type SortKey: 'static + PartialOrd + Send + Sync + Clone;
/// Sort key used by at the segment level by the `SegmentSortKeyComputer`.
///
/// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
/// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
/// Computes the sort key and pushes the document in a TopN Computer.
///
/// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key = self.segment_sort_key(doc, score);
top_n_computer.push(sort_key, doc);
}
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// This method must be consistent with the `SortKey` ordering.
#[inline(always)]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
NaturalComparator.compare(left, right)
}
/// Implementing this method makes it possible to avoid computing
/// a sort_key entirely if we can assess that it won't pass a threshold
/// with a partial computation.
///
/// This is currently used for lexicographic sorting.
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let sort_key = self.segment_sort_key(doc_id, score);
let cmp = self.compare_segment_sort_key(&sort_key, threshold);
if cmp == Ordering::Less {
None
} else {
Some((cmp, sort_key))
}
}
/// Convert a segment level sort key into the global sort key.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
/// `SortKeyComputer` defines the sort key to be used by a TopK Collector.
///
/// The `SortKeyComputer` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
/// The sort key type.
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type.
type Comparator: Comparator<Self::SortKey>
+ Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
+ 'static;
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
Ok(())
}
/// Returns the sort key comparator.
fn comparator(&self) -> Self::Comparator {
Self::Comparator::default()
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similary score might not be computed (as an optimization),
/// and the score fed in the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
false
}
/// Sorting by score has a overriding implementation for BM25 scores, using Block-WAND.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let with_scoring = self.requires_scoring();
let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
let mut segment_top_key_collector = TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
};
default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
Ok(segment_top_key_collector.harvest())
}
/// Builds a child sort key computer for a specific segment.
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
for (HeadSortKeyComputer, TailSortKeyComputer)
where
HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer,
{
type SortKey = (
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = (
HeadSortKeyComputer::Comparator,
TailSortKeyComputer::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(self.0.comparator(), self.1.comparator())
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
))
}
/// Checks whether the schema is compatible with the sort key computer.
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
Ok(())
}
/// Indicates whether the sort key actually uses the similarity score (by default BM25).
/// If set to false, the similary score might not be computed (as an optimization),
/// and the score fed in the segment sort key computer could take any value.
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring()
}
}
impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
where
HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
TailSegmentSortKeyComputer: SegmentSortKeyComputer,
{
type SortKey = (
HeadSegmentSortKeyComputer::SortKey,
TailSegmentSortKeyComputer::SortKey,
);
type SegmentSortKey = (
HeadSegmentSortKeyComputer::SegmentSortKey,
TailSegmentSortKeyComputer::SegmentSortKey,
);
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering.
///
/// By default, it uses the natural ordering.
#[inline]
fn compare_segment_sort_key(
&self,
left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey,
) -> Ordering {
self.0
.compare_segment_sort_key(&left.0, &right.0)
.then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
let sort_key: Self::SegmentSortKey;
if let Some(threshold) = &top_n_computer.threshold {
if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
sort_key = lazy_sort_key;
} else {
return;
}
} else {
sort_key = self.segment_sort_key(doc, score);
};
top_n_computer.append_doc(doc, sort_key);
}
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
let head_sort_key = self.0.segment_sort_key(doc, score);
let tail_sort_key = self.1.segment_sort_key(doc, score);
(head_sort_key, tail_sort_key)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
let (head_threshold, tail_threshold) = threshold;
let (head_cmp, head_sort_key) =
self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
if head_cmp == Ordering::Equal {
let (tail_cmp, tail_sort_key) =
self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
Some((tail_cmp, (head_sort_key, tail_sort_key)))
} else {
let tail_sort_key = self.1.segment_sort_key(doc_id, score);
Some((head_cmp, (head_sort_key, tail_sort_key)))
}
}
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
let (head_sort_key, tail_sort_key) = sort_key;
(
self.0.convert_segment_sort_key(head_sort_key),
self.1.convert_segment_sort_key(tail_sort_key),
)
}
}
/// This struct is used as an adapter to take a sort key computer and map its score to another
/// new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
sort_key_computer: T,
map: fn(PreviousSortKey) -> NewSortKey,
}
impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score)
}
fn accept_sort_key_lazy(
&mut self,
doc_id: DocId,
score: Score,
threshold: &Self::SegmentSortKey,
) -> Option<(Ordering, Self::SegmentSortKey)> {
self.sort_key_computer
.accept_sort_key_lazy(doc_id, score, threshold)
}
#[inline(always)]
fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
&mut self,
doc: DocId,
score: Score,
top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
) {
self.sort_key_computer
.compute_sort_key_and_collect(doc, score, top_n_computer);
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
(self.map)(
self.sort_key_computer
.convert_segment_sort_key(segment_sort_key),
)
}
}
// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c,
// ...) as the chain (a, (b, (c, ...)))
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3> SortKeyComputer
for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
{
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
);
type Child = MappedSegmentSortKeyComputer<
<(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
),
Self::SortKey,
>;
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
);
fn comparator(&self) -> Self::Comparator {
(
self.0.comparator(),
self.1.comparator(),
self.2.comparator(),
)
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
map,
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
}
}
impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3, SortKeyComputer4> SortKeyComputer
for (
SortKeyComputer1,
SortKeyComputer2,
SortKeyComputer3,
SortKeyComputer4,
)
where
SortKeyComputer1: SortKeyComputer,
SortKeyComputer2: SortKeyComputer,
SortKeyComputer3: SortKeyComputer,
SortKeyComputer4: SortKeyComputer,
{
type Child = MappedSegmentSortKeyComputer<
<(
SortKeyComputer1,
(SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
) as SortKeyComputer>::Child,
(
SortKeyComputer1::SortKey,
(
SortKeyComputer2::SortKey,
(SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
),
),
Self::SortKey,
>;
type SortKey = (
SortKeyComputer1::SortKey,
SortKeyComputer2::SortKey,
SortKeyComputer3::SortKey,
SortKeyComputer4::SortKey,
);
type Comparator = (
SortKeyComputer1::Comparator,
SortKeyComputer2::Comparator,
SortKeyComputer3::Comparator,
SortKeyComputer4::Comparator,
);
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
Ok(MappedSegmentSortKeyComputer {
sort_key_computer: (
sort_key_computer1,
(sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
),
map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
(sort_key1, sort_key2, sort_key3, sort_key4)
},
})
}
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.0.check_schema(schema)?;
self.1.check_schema(schema)?;
self.2.check_schema(schema)?;
self.3.check_schema(schema)?;
Ok(())
}
fn requires_scoring(&self) -> bool {
self.0.requires_scoring()
|| self.1.requires_scoring()
|| self.2.requires_scoring()
|| self.3.requires_scoring()
}
}
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
SegmentF: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
type SortKey = TSortKey;
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TSortKey> SegmentSortKeyComputer for F
where
F: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
type SortKey = TSortKey;
type SegmentSortKey = TSortKey;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc)
}
/// Convert a segment level score into the global level score.
fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
sort_key
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Index, Order, SegmentReader};
fn build_test_index() -> Index {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(crate::TantivyDocument::default())
.unwrap();
index_writer.commit().unwrap();
index
}
#[test]
fn test_lazy_score_computer() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
"b"
}
};
let lazy_score_computer = (score_computer_primary, score_computer_secondary);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, "b");
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a"));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c"));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c"));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
}
#[test]
fn test_lazy_score_computer_dynamic_ordering() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
2u32
}
};
let lazy_score_computer = (
(score_computer_primary, Order::Desc),
(score_computer_secondary, Order::Asc),
);
let index = build_test_index();
let searcher = index.reader().unwrap().searcher();
let mut segment_sort_key_computer = lazy_score_computer
.segment_sort_key_computer(searcher.segment_reader(0))
.unwrap();
let expected_sort_key = (200, 2u32);
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32));
assert!(sort_key_opt.is_none());
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32));
assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32));
assert_eq!(sort_key_opt, None);
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
}
{
let sort_key_opt =
segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
}
assert_eq!(
segment_sort_key_computer.convert_segment_sort_key(expected_sort_key),
(200u32, 2u32)
);
}
}

View File

@@ -0,0 +1,193 @@
use std::ops::Range;
use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{Collector, SegmentCollector, TopNComputer};
use crate::query::Weight;
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
sort_key_computer: TSortKeyComputer,
doc_range: Range<usize>,
}
impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
TopBySortKeyCollector {
sort_key_computer,
doc_range,
}
}
}
impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
{
type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;
type Child =
TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;
fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
self.sort_key_computer.check_schema(schema)
}
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
let segment_sort_key_computer = self
.sort_key_computer
.segment_sort_key_computer(segment_reader)?;
let topn_computer = TopNComputer::new_with_comparator(
self.doc_range.end,
self.sort_key_computer.comparator(),
);
Ok(TopBySortKeySegmentCollector {
topn_computer,
segment_ord,
segment_sort_key_computer,
})
}
fn requires_scoring(&self) -> bool {
self.sort_key_computer.requires_scoring()
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
Ok(merge_top_k(
segment_fruits.into_iter().flatten(),
self.doc_range.clone(),
self.sort_key_computer.comparator(),
))
}
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
let k = self.doc_range.end;
let docs = self
.sort_key_computer
.collect_segment_top_k(k, weight, reader, segment_ord)?;
Ok(docs)
}
}
fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
doc_range: Range<usize>,
comparator: C,
) -> Vec<(TSortKey, D)> {
if doc_range.is_empty() {
return Vec::new();
}
let mut top_collector: TopNComputer<TSortKey, D, C> =
TopNComputer::new_with_comparator(doc_range.end, comparator);
for (sort_key, doc) in sort_key_docs {
top_collector.push(sort_key, doc);
}
top_collector
.into_sorted_vec()
.into_iter()
.skip(doc_range.start)
.map(|cdoc| (cdoc.sort_key, cdoc.doc))
.collect()
}
pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
pub(crate) segment_ord: u32,
pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}
impl<TSegmentSortKeyComputer, C> SegmentCollector
for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
{
type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
self.segment_sort_key_computer.compute_sort_key_and_collect(
doc,
score,
&mut self.topn_computer,
);
}
fn harvest(self) -> Self::Fruit {
let segment_ord = self.segment_ord;
let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
.topn_computer
.into_vec()
.into_iter()
.map(|comparable_doc| {
let sort_key = self
.segment_sort_key_computer
.convert_segment_sort_key(comparable_doc.sort_key);
(
sort_key,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect();
segment_hits
}
}
#[cfg(test)]
mod tests {
use std::ops::Range;
use rand;
use rand::seq::SliceRandom as _;
use super::merge_top_k;
use crate::collector::sort_key::ComparatorEnum;
use crate::Order;
fn test_merge_top_k_aux(
order: Order,
doc_range: Range<usize>,
expected: &[(crate::Score, usize)],
) {
let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
vals.shuffle(&mut rand::thread_rng());
let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
assert_eq!(&vals_merged, expected);
}
#[test]
fn test_merge_top_k() {
test_merge_top_k_aux(Order::Asc, 0..0, &[]);
test_merge_top_k_aux(Order::Asc, 3..3, &[]);
test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(
Order::Asc,
0..11,
&[
(0.0f32, 0),
(1.0f32, 1),
(2.0f32, 2),
(3.0f32, 3),
(4.0f32, 4),
(5.0f32, 5),
(6.0f32, 6),
(7.0f32, 7),
(8.0f32, 8),
(9.0f32, 9),
],
);
test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
}
}

View File

@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_some_collector = FilterCollector::new(
"price".to_string(),
&|value: u64| value > 20_120u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let top_docs = searcher.search(&query, &filter_some_collector)?;
@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
"price".to_string(),
&|value| value < 5u64,
TopDocs::with_limit(2),
TopDocs::with_limit(2).order_by_score(),
);
let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();
@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
> 0
}
let filter_dates_collector =
FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
let filter_dates_collector = FilterCollector::new(
"date".to_string(),
&date_filter,
TopDocs::with_limit(5).order_by_score(),
);
let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;
assert_eq!(filtered_date_docs.len(), 2);

View File

@@ -1,12 +1,7 @@
use std::cmp::Ordering;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use super::top_score_collector::TopNComputer;
use crate::index::SegmentReader;
use crate::{DocAddress, DocId, SegmentOrdinal};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
@@ -19,7 +14,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
pub feature: T,
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
@@ -28,9 +23,9 @@ pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
.field("feature", &self.feature)
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
@@ -46,8 +41,8 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R>
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.feature
.partial_cmp(&other.feature)
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
@@ -67,308 +62,3 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T,
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}
pub(crate) struct TopCollector<T> {
pub limit: usize,
pub offset: usize,
_marker: PhantomData<T>,
}
impl<T> TopCollector<T>
where T: PartialOrd + Clone
{
/// Creates a top collector, with a number of documents equal to "limit".
///
/// # Panics
/// The method panics if limit is 0
pub fn with_limit(limit: usize) -> TopCollector<T> {
assert!(limit >= 1, "Limit must be strictly greater than 0.");
Self {
limit,
offset: 0,
_marker: PhantomData,
}
}
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
self.offset = offset;
self
}
pub fn merge_fruits(
&self,
children: Vec<Vec<(T, DocAddress)>>,
) -> crate::Result<Vec<(T, DocAddress)>> {
if self.limit == 0 {
return Ok(Vec::new());
}
let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
for child_fruit in children {
for (feature, doc) in child_fruit {
top_collector.push(feature, doc);
}
}
Ok(top_collector
.into_sorted_vec()
.into_iter()
.skip(self.offset)
.map(|cdoc| (cdoc.feature, cdoc.doc))
.collect())
}
pub(crate) fn for_segment<F: PartialOrd + Clone>(
&self,
segment_id: SegmentOrdinal,
_: &SegmentReader,
) -> TopSegmentCollector<F> {
TopSegmentCollector::new(segment_id, self.limit + self.offset)
}
/// Create a new TopCollector with the same limit and offset.
///
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
/// to fail.
#[doc(hidden)]
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
TopCollector {
limit: self.limit,
offset: self.offset,
_marker: PhantomData,
}
}
}
/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
/// We reverse the order of the feature in order to
/// have top-semantics instead of bottom semantics.
topn_computer: TopNComputer<T, DocId>,
segment_ord: u32,
}
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
topn_computer: TopNComputer::new(limit),
segment_ord,
}
}
}
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocAddress)> {
let segment_ord = self.segment_ord;
self.topn_computer
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| {
(
comparable_doc.feature,
DocAddress {
segment_ord,
doc_id: comparable_doc.doc,
},
)
})
.collect()
}
/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline]
pub fn collect(&mut self, doc: DocId, feature: T) {
self.topn_computer.push(feature, doc);
}
}
#[cfg(test)]
mod tests {
use super::{TopCollector, TopSegmentCollector};
use crate::DocAddress;
#[test]
fn test_top_collector_not_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
assert_eq!(
top_collector.harvest(),
vec![
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_collector_at_capacity() {
let mut top_collector = TopSegmentCollector::new(0, 4);
top_collector.collect(1, 0.8);
top_collector.collect(3, 0.2);
top_collector.collect(5, 0.3);
top_collector.collect(7, 0.9);
top_collector.collect(9, -0.2);
assert_eq!(
top_collector.harvest(),
vec![
(0.9, DocAddress::new(0, 7)),
(0.8, DocAddress::new(0, 1)),
(0.3, DocAddress::new(0, 5)),
(0.2, DocAddress::new(0, 3))
]
);
}
#[test]
fn test_top_segment_collector_stable_ordering_for_equal_feature() {
// given that the documents are collected in ascending doc id order,
// when harvesting we have to guarantee stable sorting in case of a tie
// on the score
let doc_ids_collection = [4, 5, 6];
let score = 3.3f32;
let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
for id in &doc_ids_collection {
top_collector_limit_2.collect(*id, score);
}
let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
for id in &doc_ids_collection {
top_collector_limit_3.collect(*id, score);
}
assert_eq!(
top_collector_limit_2.harvest(),
top_collector_limit_3.harvest()[..2].to_vec(),
);
}
#[test]
fn test_top_collector_with_limit_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
(0.7, DocAddress::new(0, 3)),
(0.6, DocAddress::new(0, 4)),
(0.5, DocAddress::new(0, 5)),
]])
.unwrap();
assert_eq!(
results,
vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3)),]
);
}
#[test]
fn test_top_collector_with_limit_larger_than_set_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![(0.8, DocAddress::new(0, 2)),]);
}
#[test]
fn test_top_collector_with_limit_and_offset_larger_than_set() {
let collector = TopCollector::with_limit(2).and_offset(20);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress::new(0, 1)),
(0.8, DocAddress::new(0, 2)),
]])
.unwrap();
assert_eq!(results, vec![]);
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use test::Bencher;
use super::TopSegmentCollector;
#[bench]
fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 400);
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
b.iter(|| {
for i in 0..100 {
top_collector.collect(i, 0.8);
}
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
for i in 0..100 {
top_collector.collect(i, 0.8);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
#[bench]
fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
b.iter(|| {
let mut top_collector = TopSegmentCollector::new(0, 100);
let mut score = 1.0;
for i in 0..100 {
score += 1.0;
top_collector.collect(i, score);
}
// it would be nice to be able to do the setup N times but still
// measure only harvest(). We can't since harvest() consumes
// the top_collector.
top_collector.harvest()
});
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,124 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Result, Score, SegmentReader};
pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
}
impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
where TScore: Clone + PartialOrd
{
pub fn new(
score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>,
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
TweakedScoreTopCollector {
score_tweaker,
collector,
}
}
}
/// A `ScoreSegmentTweaker` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`ScoreTweaker`].
pub trait ScoreSegmentTweaker<TScore>: 'static {
/// Tweak the given `score` for the document `doc`.
fn score(&mut self, doc: DocId, score: Score) -> TScore;
}
/// `ScoreTweaker` makes it possible to tweak the score
/// emitted by the scorer into another one.
///
/// The `ScoreTweaker` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait ScoreTweaker<TScore>: Sync {
/// Type of the associated [`ScoreSegmentTweaker`].
type Child: ScoreSegmentTweaker<TScore>;
/// Builds a child tweaker for a specific segment. The child scorer is associated with
/// a specific segment.
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}
impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
where
TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
TScore: 'static + PartialOrd + Clone + Send + Sync,
{
type Fruit = Vec<(TScore, DocAddress)>;
type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
) -> Result<Self::Child> {
let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
Ok(TopTweakedScoreSegmentCollector {
segment_collector,
segment_scorer,
})
}
fn requires_scoring(&self) -> bool {
true
}
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits)
}
}
pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
segment_collector: TopSegmentCollector<TScore>,
segment_scorer: TSegmentScoreTweaker,
}
impl<TSegmentScoreTweaker, TScore> SegmentCollector
for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
TScore: 'static + PartialOrd + Clone + Send + Sync,
TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
{
type Fruit = Vec<(TScore, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
let score = self.segment_scorer.score(doc, score);
self.segment_collector.collect(doc, score);
}
fn harvest(self) -> Vec<(TScore, DocAddress)> {
self.segment_collector.harvest()
}
}
impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
type Child = TSegmentScoreTweaker;
fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where F: 'static + FnMut(DocId, Score) -> TScore
{
fn score(&mut self, doc: DocId, score: Score) -> TScore {
(self)(doc, score)
}
}

View File

@@ -69,7 +69,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecis
.parse_query("dateformat")
.expect("Failed to parse query");
let top_docs = searcher
.search(&query, &TopDocs::with_limit(1))
.search(&query, &TopDocs::with_limit(1).order_by_score())
.expect("Search failed");
assert_eq!(top_docs.len(), 1, "Expected 1 search result");

View File

@@ -48,7 +48,15 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>,
{
match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
Executor::SingleThread => {
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect();
let num_fruits = args.len();

View File

@@ -3,6 +3,7 @@ use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;
use crate::indexer::indexing_term::IndexingTerm;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
@@ -77,7 +78,7 @@ fn index_json_object<'a, V: Value<'a>>(
doc: DocId,
json_visitor: V::ObjectIter,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
@@ -107,17 +108,17 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
doc: DocId,
json_value: V,
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
positions_per_path: &mut IndexingPositionsPerPath,
) {
let set_path_id = |term_buffer: &mut Term, unordered_id: u32| {
let set_path_id = |term_buffer: &mut IndexingTerm, unordered_id: u32| {
term_buffer.truncate_value_bytes(0);
term_buffer.append_bytes(&unordered_id.to_be_bytes());
};
let set_type = |term_buffer: &mut Term, typ: Type| {
let set_type = |term_buffer: &mut IndexingTerm, typ: Type| {
term_buffer.append_bytes(&[typ.to_code()]);
};

View File

@@ -225,6 +225,7 @@ impl Searcher {
enabled_scoring: EnableScoring,
) -> crate::Result<C::Fruit> {
let weight = query.weight(enabled_scoring)?;
collector.check_schema(self.schema())?;
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {

View File

@@ -108,7 +108,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// Opens a file and returns a boxed `FileHandle`.
///
/// Users of `Directory` should typically call `Directory::open_read(...)`,
/// while `Directory` implementor should implement `get_file_handle()`.
/// while `Directory` implementer should implement `get_file_handle()`.
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError>;
/// Once a virtual file is open, its data may not

View File

@@ -104,7 +104,7 @@ pub enum TantivyError {
#[error("{0:?}")]
IncompatibleIndex(Incompatibility),
/// An internal error occurred. This is are internal states that should not be reached.
/// e.g. a datastructure is incorrectly inititalized.
/// e.g. a datastructure is incorrectly initialized.
#[error("Internal error: '{0}'")]
InternalError(String),
#[error("Deserialize error: {0}")]

View File

@@ -726,22 +726,22 @@ mod tests {
.column_opt::<DateTime>("multi_date")
.unwrap()
.unwrap();
let mut dates = Vec::new();
{
assert_eq!(date_fast_field.get_val(0).into_timestamp_nanos(), 1i64);
dates_fast_field.fill_vals(0u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(0u32).collect();
assert_eq!(dates.len(), 2);
assert_eq!(dates[0].into_timestamp_nanos(), 2i64);
assert_eq!(dates[1].into_timestamp_nanos(), 3i64);
}
{
assert_eq!(date_fast_field.get_val(1).into_timestamp_nanos(), 4i64);
dates_fast_field.fill_vals(1u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(1u32).collect();
assert!(dates.is_empty());
}
{
assert_eq!(date_fast_field.get_val(2).into_timestamp_nanos(), 0i64);
dates_fast_field.fill_vals(2u32, &mut dates);
let dates: Vec<DateTime> = dates_fast_field.values_for_doc(2u32).collect();
assert_eq!(dates.len(), 2);
assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
assert_eq!(dates[1].into_timestamp_nanos(), 6i64);

View File

@@ -276,13 +276,14 @@ impl Default for IndexSettings {
}
/// The order to sort by
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Order {
/// Ascending Order
Asc,
/// Descending Order
Desc,
}
impl Order {
/// return if the Order is ascending
pub fn is_asc(&self) -> bool {

View File

@@ -608,7 +608,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(100u64)),
postings_size: Some(ByteCount::from(1_000u64)),
positions_size: Some(ByteCount::from(2_000u64)),
fast_size: Some(ByteCount::from(1_000u64).into()),
fast_size: Some(ByteCount::from(1_000u64)),
};
let field_metadata2 = FieldMetadata {
field_name: "a".to_string(),
@@ -617,7 +617,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(80u64)),
postings_size: Some(ByteCount::from(1_500u64)),
positions_size: Some(ByteCount::from(2_500u64)),
fast_size: Some(ByteCount::from(3_000u64).into()),
fast_size: Some(ByteCount::from(3_000u64)),
};
let expected = FieldMetadata {
field_name: "a".to_string(),
@@ -626,7 +626,7 @@ mod test {
term_dictionary_size: Some(ByteCount::from(180u64)),
postings_size: Some(ByteCount::from(2_500u64)),
positions_size: Some(ByteCount::from(4_500u64)),
fast_size: Some(ByteCount::from(4_000u64).into()),
fast_size: Some(ByteCount::from(4_000u64)),
};
assert_merge(
&[vec![field_metadata1.clone()], vec![field_metadata2]],

View File

@@ -513,7 +513,7 @@ impl<D: Document> IndexWriter<D> {
/// let searcher = index.reader()?.searcher();
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query_promo = query_parser.parse_query("Prometheus")?;
/// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1))?;
/// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1).order_by_score())?;
///
/// assert!(top_docs_promo.is_empty());
/// Ok(())
@@ -946,11 +946,11 @@ mod tests {
let searcher = reader.searcher();
let a_docs = searcher
.search(&a_query, &TopDocs::with_limit(1))
.search(&a_query, &TopDocs::with_limit(1).order_by_score())
.expect("search for a failed");
let b_docs = searcher
.search(&b_query, &TopDocs::with_limit(1))
.search(&b_query, &TopDocs::with_limit(1).order_by_score())
.expect("search for b failed");
assert_eq!(a_docs.len(), 1);
@@ -2014,8 +2014,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(1000)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(1000).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1).collect::<Vec<_>>()
};
@@ -2449,8 +2450,9 @@ mod tests {
Term::from_field_u64(id_field, existing_id),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Was failing
@@ -2491,8 +2493,9 @@ mod tests {
Term::from_field_i64(id_field, 10i64),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Fails
@@ -2500,8 +2503,9 @@ mod tests {
Term::from_field_i64(id_field, 30i64),
IndexRecordOption::Basic,
);
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
assert_eq!(top_docs.len(), 1); // Fails

View File

@@ -0,0 +1,184 @@
use std::net::Ipv6Addr;
use columnar::MonotonicallyMappableToU128;
use crate::fastfield::FastValue;
use crate::schema::{Field, Type};
/// Term represents the value that the token can take.
/// It's a serialized representation over different types.
///
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
/// 4 bytes are the field id, and the last byte is the type.
///
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
#[derive(Clone)]
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
where B: AsRef<[u8]>;
/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 5;
impl IndexingTerm {
/// Create a new Term with a buffer with a given capacity.
pub fn with_capacity(capacity: usize) -> IndexingTerm {
let mut data = Vec::with_capacity(TERM_METADATA_LENGTH + capacity);
data.resize(TERM_METADATA_LENGTH, 0u8);
IndexingTerm(data)
}
/// Panics when the term is not empty... ie: some value is set.
/// Use `clear_with_field_and_type` in that case.
///
/// Sets field and the type.
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
assert!(self.is_empty());
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
self.0[4] = typ.to_code();
}
/// Is empty if there are no value bytes.
pub fn is_empty(&self) -> bool {
self.0.len() == TERM_METADATA_LENGTH
}
/// Removes the value_bytes and set the field and type code.
pub(crate) fn clear_with_field_and_type(&mut self, typ: Type, field: Field) {
self.truncate_value_bytes(0);
self.set_field_and_type(field, typ);
}
/// Sets a u64 value in the term.
///
/// U64 are serialized using (8-byte) BigEndian
/// representation.
/// The use of BigEndian has the benefit of preserving
/// the natural order of the values.
pub fn set_u64(&mut self, val: u64) {
self.set_fast_value(val);
}
/// Sets a `i64` value in the term.
pub fn set_i64(&mut self, val: i64) {
self.set_fast_value(val);
}
/// Sets a `f64` value in the term.
pub fn set_f64(&mut self, val: f64) {
self.set_fast_value(val);
}
/// Sets a `bool` value in the term.
pub fn set_bool(&mut self, val: bool) {
self.set_fast_value(val);
}
fn set_fast_value<T: FastValue>(&mut self, val: T) {
self.set_bytes(val.to_u64().to_be_bytes().as_ref());
}
/// Append a type marker + fast value to a term.
/// This is used in JSON type to append a fast value after the path.
///
/// It will not clear existing bytes.
pub fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) {
self.0.push(T::to_type().to_code());
let value = val.to_u64();
self.0.extend(value.to_be_bytes().as_ref());
}
/// Sets a `Ipv6Addr` value in the term.
pub fn set_ip_addr(&mut self, val: Ipv6Addr) {
self.set_bytes(val.to_u128().to_be_bytes().as_ref());
}
/// Sets the value of a `Bytes` field.
pub fn set_bytes(&mut self, bytes: &[u8]) {
self.truncate_value_bytes(0);
self.0.extend(bytes);
}
/// Truncates the value bytes of the term. Value and field type stays the same.
pub fn truncate_value_bytes(&mut self, len: usize) {
self.0.truncate(len + TERM_METADATA_LENGTH);
}
/// The length of the bytes.
pub fn len_bytes(&self) -> usize {
self.0.len() - TERM_METADATA_LENGTH
}
/// Appends value bytes to the Term.
///
/// This function returns the segment that has just been added.
#[inline]
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
}
impl<B> IndexingTerm<B>
where B: AsRef<[u8]>
{
/// Returns the serialized representation of Term.
/// This includes field_id, value type and value.
///
/// Do NOT rely on this byte representation in the index.
/// This value is likely to change in the future.
#[inline]
pub fn serialized_term(&self) -> &[u8] {
self.0.as_ref()
}
}
#[cfg(test)]
mod tests {
use crate::schema::*;
#[test]
pub fn test_term_str() {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("text", STRING);
let title_field = schema_builder.add_text_field("title", STRING);
let term = Term::from_field_text(title_field, "test");
assert_eq!(term.field(), title_field);
assert_eq!(term.typ(), Type::Str);
assert_eq!(term.value().as_str(), Some("test"))
}
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
/// <field> + <type byte> + <value len>
///
/// - <field> is a big endian encoded u32 field id
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
/// type is the type of the leaf of the json.
/// - <value> is, if this is not the json term, a binary representation specific to the type.
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
#[test]
pub fn test_term_u64() {
let mut schema_builder = Schema::builder();
let count_field = schema_builder.add_u64_field("count", INDEXED);
let term = Term::from_field_u64(count_field, 983u64);
assert_eq!(term.field(), count_field);
assert_eq!(term.typ(), Type::U64);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_u64(), Some(983u64))
}
#[test]
pub fn test_term_bool() {
let mut schema_builder = Schema::builder();
let bool_field = schema_builder.add_bool_field("bool", INDEXED);
let term = Term::from_field_bool(bool_field, true);
assert_eq!(term.field(), bool_field);
assert_eq!(term.typ(), Type::Bool);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_bool(), Some(true))
}
}

View File

@@ -104,8 +104,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![my_text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(3).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};

View File

@@ -1518,7 +1518,8 @@ mod tests {
let searcher = reader.searcher();
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(searcher.segment_reader(0u32), 1.0)?;
.term_scorer_for_test(searcher.segment_reader(0u32), 1.0)?
.unwrap();
assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
assert_nearly_equals!(term_scorer.score(), 0.0079681855);
@@ -1533,7 +1534,8 @@ mod tests {
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries
// there.
for doc in segment_reader.doc_ids_alive() {
@@ -1558,7 +1560,8 @@ mod tests {
let segment_reader = searcher.segment_reader(0u32);
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);

View File

@@ -12,6 +12,7 @@ mod doc_opstamp_mapping;
mod flat_map_with_buffer;
pub(crate) mod index_writer;
pub(crate) mod index_writer_status;
pub(crate) mod indexing_term;
mod log_merge_policy;
mod merge_index_test;
mod merge_operation;
@@ -181,6 +182,7 @@ mod tests_mmap {
let field_name_out = ".";
test_json_field_name(field_name_in, field_name_out);
}
#[test]
fn test_json_field_dot() {
// Test when field name contains a '.'
@@ -587,7 +589,9 @@ mod tests_mmap {
};
let query_str = &format!("{}:{}", indexed_field.field_name, val);
let query = query_parser.parse_query(query_str).unwrap();
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
let count_docs = searcher
.search(&*query, &TopDocs::with_limit(2).order_by_score())
.unwrap();
if indexed_field.field_name.contains("empty") || indexed_field.typ == Type::Json {
assert_eq!(count_docs.len(), 0);
} else {
@@ -659,7 +663,9 @@ mod tests_mmap {
for (indexed_field, val) in fields_and_vals.iter() {
let query_str = &format!("{indexed_field}:{val}");
let query = query_parser.parse_query(query_str).unwrap();
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
let count_docs = searcher
.search(&*query, &TopDocs::with_limit(2).order_by_score())
.unwrap();
assert!(!count_docs.is_empty(), "{indexed_field}:{val}");
}
// Test if field name can be used for aggregation

View File

@@ -1052,8 +1052,9 @@ mod tests {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> =
searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(3).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};

View File

@@ -7,6 +7,7 @@ use super::operation::AddOperation;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent};
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::segment_serializer::SegmentSerializer;
use crate::json_utils::{index_json_value, IndexingPositionsPerPath};
use crate::postings::{
@@ -14,7 +15,7 @@ use crate::postings::{
PerFieldPostingsWriter, PostingsWriter,
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, Term, DATE_TIME_PRECISION_INDEXED};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};
@@ -55,7 +56,7 @@ pub struct SegmentWriter {
pub(crate) json_positions_per_path: IndexingPositionsPerPath,
pub(crate) doc_opstamps: Vec<Opstamp>,
per_field_text_analyzers: Vec<TextAnalyzer>,
term_buffer: Term,
term_buffer: IndexingTerm,
schema: Schema,
}
@@ -112,7 +113,7 @@ impl SegmentWriter {
)?,
doc_opstamps: Vec::with_capacity(1_000),
per_field_text_analyzers,
term_buffer: Term::with_capacity(16),
term_buffer: IndexingTerm::with_capacity(16),
schema,
})
}
@@ -519,7 +520,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 1);
@@ -528,7 +529,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 2);
}
@@ -561,7 +562,7 @@ mod tests {
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4))
.search(&text_query, &TopDocs::with_limit(4).order_by_score())
.unwrap();
assert_eq!(score_docs.len(), 1);
};

View File

@@ -42,7 +42,6 @@ mod test {
use super::Stamper;
#[expect(clippy::redundant_clone)]
#[test]
fn test_stamper() {
let stamper = Stamper::new(7u64);
@@ -58,7 +57,6 @@ mod test {
assert_eq!(stamper.stamp(), 15u64);
}
#[expect(clippy::redundant_clone)]
#[test]
fn test_stamper_revert() {
let stamper = Stamper::new(7u64);

View File

@@ -85,7 +85,7 @@
//! // Perform search.
//! // `topdocs` contains the 10 most relevant doc ids, sorted by decreasing scores...
//! let top_docs: Vec<(Score, DocAddress)> =
//! searcher.search(&query, &TopDocs::with_limit(10))?;
//! searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
//!
//! for (_score, doc_address) in top_docs {
//! // Retrieve the actual content of documents given its `doc_address`.
@@ -125,7 +125,7 @@
//!
//! - **Searching**: [Searcher] searches the segments with anything that implements
//! [Query](query::Query) and merges the results. The list of [supported
//! queries](query::Query#implementors). Custom Queries are supported by implementing the
//! queries](query::Query#implementers). Custom Queries are supported by implementing the
//! [Query](query::Query) trait.
//!
//! - **[Directory](directory)**: Abstraction over the storage where the index data is stored.

View File

@@ -3,13 +3,14 @@ use std::io;
use common::json_path_writer::JSON_END_OF_PATH;
use stacker::Addr;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::{Field, Type};
use crate::schema::{Field, Type, ValueBytes};
use crate::tokenizer::TokenStream;
use crate::{DocId, Term};
use crate::DocId;
/// The `JsonPostingsWriter` is odd in that it relies on a hidden contract:
///
@@ -33,7 +34,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
&mut self,
doc: crate::DocId,
pos: u32,
term: &crate::Term,
term: &IndexingTerm,
ctx: &mut IndexingContext,
) {
self.non_str_posting_writer.subscribe(doc, pos, term, ctx);
@@ -43,7 +44,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
&mut self,
doc_id: DocId,
token_stream: &mut dyn TokenStream,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
ctx: &mut IndexingContext,
indexing_position: &mut IndexingPosition,
) {
@@ -64,40 +65,38 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
ctx: &IndexingContext,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let mut term_buffer = Term::with_capacity(48);
let mut term_buffer = JsonTermSerializer(Vec::with_capacity(48));
let mut buffer_lender = BufferLender::default();
term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
let mut prev_term_id = u32::MAX;
let mut term_path_len = 0; // this will be set in the first iteration
for (_field, path_id, term, addr) in ordered_term_addrs {
if prev_term_id != path_id.path_id() {
term_buffer.truncate_value_bytes(0);
term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());
term_buffer.append_bytes(&[JSON_END_OF_PATH]);
term_path_len = term_buffer.len_bytes();
term_buffer.clear();
term_buffer.append_json_path(ordered_id_to_path[path_id.path_id() as usize]);
term_path_len = term_buffer.len();
prev_term_id = path_id.path_id();
}
term_buffer.truncate_value_bytes(term_path_len);
term_buffer.truncate(term_path_len);
term_buffer.append_bytes(term);
if let Some(json_value) = term_buffer.value().as_json_value_bytes() {
let typ = json_value.typ();
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.serialized_value_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
} else {
SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
term_buffer.serialized_value_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
}
let json_value = ValueBytes::wrap(term);
let typ = json_value.typ();
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
} else {
SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
term_buffer.as_bytes(),
*addr,
&mut buffer_lender,
ctx,
serializer,
)?;
}
}
Ok(())
@@ -107,3 +106,48 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
self.str_posting_writer.total_num_tokens() + self.non_str_posting_writer.total_num_tokens()
}
}
struct JsonTermSerializer(Vec<u8>);
impl JsonTermSerializer {
/// Appends a JSON path to the Term.
/// The path is terminated by a special end-of-path 0 byte.
#[inline]
pub fn append_json_path(&mut self, path: &str) {
let bytes = path.as_bytes();
// Replace any occurrence of the end-of-path byte with Ascii '0' byte.
if bytes.contains(&JSON_END_OF_PATH) {
self.0.extend(
bytes
.iter()
.map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
);
} else {
self.0.extend_from_slice(bytes);
}
self.0.push(JSON_END_OF_PATH);
}
/// Appends value bytes to the Term.
///
/// This function returns the segment that has just been added.
#[inline]
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
fn clear(&mut self) {
self.0.clear();
}
fn truncate(&mut self, len: usize) {
self.0.truncate(len);
}
fn len(&self) -> usize {
self.0.len()
}
fn as_bytes(&self) -> &[u8] {
&self.0
}
}

View File

@@ -5,6 +5,7 @@ use std::ops::Range;
use stacker::Addr;
use crate::fieldnorm::FieldNormReaders;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
@@ -111,7 +112,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
/// * term - the term
/// * ctx - Contains a term hashmap and a memory arena to store all necessary posting list
/// information.
fn subscribe(&mut self, doc: DocId, pos: u32, term: &Term, ctx: &mut IndexingContext);
fn subscribe(&mut self, doc: DocId, pos: u32, term: &IndexingTerm, ctx: &mut IndexingContext);
/// Serializes the postings on disk.
/// The actual serialization format is handled by the `PostingsSerializer`.
@@ -128,7 +129,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
&mut self,
doc_id: DocId,
token_stream: &mut dyn TokenStream,
term_buffer: &mut Term,
term_buffer: &mut IndexingTerm,
ctx: &mut IndexingContext,
indexing_position: &mut IndexingPosition,
) {
@@ -198,7 +199,13 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
#[inline]
fn subscribe(&mut self, doc: DocId, position: u32, term: &Term, ctx: &mut IndexingContext) {
fn subscribe(
&mut self,
doc: DocId,
position: u32,
term: &IndexingTerm,
ctx: &mut IndexingContext,
) {
debug_assert!(term.serialized_term().len() >= 4);
self.total_num_tokens += 1;
let (term_index, arena) = (&mut ctx.term_index, &mut ctx.arena);

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
@@ -57,13 +59,13 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
}
fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let mut cache: [Score; 256] = [0.0; 256];
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
*cache_mut = cached_tf_component(fieldnorm, average_fieldnorm);
}
cache
Arc::new(cache)
}
/// A struct used for computing BM25 scores.
@@ -71,17 +73,20 @@ fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
pub struct Bm25Weight {
idf_explain: Option<Explanation>,
weight: Score,
cache: [Score; 256],
cache: Arc<[Score; 256]>,
average_fieldnorm: Score,
}
impl Bm25Weight {
/// Increase the weight by a multiplicative factor.
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
if boost == 1.0f32 {
return self.clone();
}
Bm25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
cache: self.cache.clone(),
average_fieldnorm: self.average_fieldnorm,
}
}

View File

@@ -9,7 +9,7 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
use crate::query::{
intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
RequiredOptionalScorer, Scorer, Weight,
};
use crate::{DocId, Score};
@@ -97,6 +97,15 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
}
}
enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant.
Ignored,
// Only contributes to final score.
Optional(SpecializedScorer),
// Regardless of score, the should scorers may impact whether a document is matching or not.
Required(SpecializedScorer),
}
/// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>,
@@ -159,27 +168,50 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
) -> crate::Result<SpecializedScorer> {
let num_docs = reader.num_docs();
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
// Indicate how should clauses are combined with other clauses.
enum CombinationMethod {
Ignored,
// Only contributes to final score.
Optional(SpecializedScorer),
Required(SpecializedScorer),
// Indicate how should clauses are combined with must clauses.
let mut must_scorers: Vec<Box<dyn Scorer>> =
per_occur_scorers.remove(&Occur::Must).unwrap_or_default();
let must_special_scorer_counts = remove_and_count_all_and_empty_scorers(&mut must_scorers);
if must_special_scorer_counts.num_empty_scorers > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
{
let mut should_scorers = per_occur_scorers.remove(&Occur::Should).unwrap_or_default();
let should_special_scorer_counts =
remove_and_count_all_and_empty_scorers(&mut should_scorers);
let mut exclude_scorers: Vec<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.unwrap_or_default();
let exclude_special_scorer_counts =
remove_and_count_all_and_empty_scorers(&mut exclude_scorers);
if exclude_special_scorer_counts.num_all_scorers > 0 {
// We exclude all documents at one point.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let minimum_number_should_match = self
.minimum_number_should_match
.saturating_sub(should_special_scorer_counts.num_all_scorers);
let should_scorers: ShouldScorersCombinationMethod = {
let num_of_should_scorers = should_scorers.len();
if self.minimum_number_should_match > num_of_should_scorers {
if minimum_number_should_match > num_of_should_scorers {
// We don't have enough scorers to satisfy the minimum number of should matches.
// The request will match no documents.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(
match minimum_number_should_match {
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
0 => ShouldScorersCombinationMethod::Optional(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
)),
1 => CombinationMethod::Required(scorer_union(
1 => ShouldScorersCombinationMethod::Required(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
@@ -187,76 +219,120 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
n if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
// they are no different from must clauses.
must_scorers = match must_scorers.take() {
Some(mut must_scorers) => {
must_scorers.append(&mut should_scorers);
Some(must_scorers)
}
None => Some(should_scorers),
};
CombinationMethod::Ignored
must_scorers.append(&mut should_scorers);
ShouldScorersCombinationMethod::Ignored
}
_ => CombinationMethod::Required(SpecializedScorer::Other(scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
))),
}
} else {
// None of should clauses are provided.
if self.minimum_number_should_match > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} else {
CombinationMethod::Ignored
_ => ShouldScorersCombinationMethod::Required(SpecializedScorer::Other(
scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
),
)),
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
.map(|specialized_scorer: SpecializedScorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
});
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => {
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
let exclude_scorer_opt: Option<Box<dyn Scorer>> = if exclude_scorers.is_empty() {
None
} else {
let exclude_specialized_scorer: SpecializedScorer =
scorer_union(exclude_scorers, DoNothingCombiner::default, num_docs);
Some(into_box_scorer(
exclude_specialized_scorer,
DoNothingCombiner::default,
num_docs,
))
};
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
let boxed_scorer: Box<dyn Scorer> = if must_scorers.is_empty() {
// We do not have any should scorers, nor all scorers.
// There are still two cases here.
//
// If this follows the removal of some AllScorers in the should/must clauses,
// then we match all documents.
//
// Otherwise, it is really just an EmptyScorer.
if must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers
> 0
{
Box::new(AllScorer::new(reader.max_doc()))
} else {
Box::new(EmptyScorer)
}
} else {
intersect_scorers(must_scorers, num_docs)
};
SpecializedScorer::Other(boxed_scorer)
}
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 {
// Optional options are promoted to required if no must scorers exists.
should_scorer
} else {
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
),
))
} else {
SpecializedScorer::Other(must_scorer)
)))
} else {
SpecializedScorer::Other(must_scorer)
}
}
}
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
(ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => {
if must_scorers.is_empty() {
should_scorer
} else {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
}
}
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
}
(CombinationMethod::Required(should_scorer), None) => should_scorer,
// Optional options are promoted to required if no must scorers exists.
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed =
into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
let include_scorer_boxed =
into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
include_scorer_boxed,
exclude_scorer,
))))
} else {
Ok(positive_scorer)
Ok(include_scorer)
}
}
}
#[derive(Default, Copy, Clone, Debug)]
struct AllAndEmptyScorerCounts {
num_all_scorers: usize,
num_empty_scorers: usize,
}
fn remove_and_count_all_and_empty_scorers(
scorers: &mut Vec<Box<dyn Scorer>>,
) -> AllAndEmptyScorerCounts {
let mut counts = AllAndEmptyScorerCounts::default();
scorers.retain(|scorer| {
if scorer.is::<AllScorer>() {
counts.num_all_scorers += 1;
false
} else if scorer.is::<EmptyScorer>() {
counts.num_empty_scorers += 1;
false
} else {
true
}
});
counts
}
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs();
@@ -293,7 +369,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
let mut explanation = Explanation::new("BooleanClause. sum of ...", scorer.score());
for (occur, subweight) in &self.weights {
if is_positive_occur(*occur) {
if is_include_occur(*occur) {
if let Ok(child_explanation) = subweight.explain(reader, doc) {
explanation.add_detail(child_explanation);
}
@@ -377,7 +453,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
}
}
fn is_positive_occur(occur: Occur) -> bool {
fn is_include_occur(occur: Occur) -> bool {
match occur {
Occur::Must | Occur::Should => true,
Occur::MustNot => false,

View File

@@ -14,8 +14,8 @@ mod tests {
use crate::collector::TopDocs;
use crate::query::term_query::TermScorer;
use crate::query::{
EnableScoring, Intersection, Occur, Query, QueryParser, RequiredOptionalScorer, Scorer,
SumCombiner, TermQuery,
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser,
RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
};
use crate::schema::*;
use crate::{assert_nearly_equals, DocAddress, DocId, Index, IndexWriter, Score};
@@ -182,7 +182,7 @@ mod tests {
let matching_topdocs = |query: &dyn Query| {
reader
.searcher()
.search(query, &TopDocs::with_limit(3))
.search(query, &TopDocs::with_limit(3).order_by_score())
.unwrap()
};
@@ -311,4 +311,67 @@ mod tests {
assert_nearly_equals!(explanation.value(), std::f32::consts::LN_2);
Ok(())
}
#[test]
pub fn test_boolean_weight_optimization() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(text_field=>"hello"))?;
index_writer.add_document(doc!(text_field=>"hello happy"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let term_match_all: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
let term_match_some: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "happy"),
IndexRecordOption::Basic,
));
let term_match_none: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "tax"),
IndexRecordOption::Basic,
));
{
let query = BooleanQuery::from(vec![
(Occur::Must, term_match_all.box_clone()),
(Occur::Must, term_match_some.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Must, term_match_all.box_clone()),
(Occur::Must, term_match_some.box_clone()),
(Occur::Must, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<EmptyScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Should, term_match_all.box_clone()),
(Occur::Should, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<AllScorer>());
}
{
let query = BooleanQuery::from(vec![
(Occur::Should, term_match_some.box_clone()),
(Occur::Should, term_match_none.box_clone()),
]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>());
}
Ok(())
}
}

View File

@@ -53,7 +53,7 @@ use crate::{Score, Term};
/// // TermQuery "diary" and "girl" should be present and only one should be accounted in score
/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3).order_by_score())?;
/// assert_eq!(documents[0].0, documents[1].0);
/// assert_eq!(documents[1].0, documents[2].0);
///
@@ -62,7 +62,7 @@ use crate::{Score, Term};
/// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let tie_breaker = 0.7;
/// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3).order_by_score())?;
/// assert_eq!(documents[1].0, documents[2].0);
/// // For this test all terms brings the same score. So we can do easy math and assume that
/// // `DisjunctionMaxQuery` with tie breakers score should be equal

View File

@@ -127,7 +127,11 @@ impl Weight for ExistsWeight {
.any(|col| matches!(col.column_index(), ColumnIndex::Full))
{
let all_scorer = AllScorer::new(max_doc);
return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
if boost != 1.0f32 {
return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
} else {
return Ok(Box::new(all_scorer));
}
}
// If we have a single dynamic column, use ExistsDocSet

View File

@@ -67,7 +67,7 @@ impl Automaton for DfaWrapper {
/// {
/// let term = Term::from_field_text(title, "Diary");
/// let query = FuzzyTermQuery::new(term, 1, true);
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count)).unwrap();
/// assert_eq!(count, 2);
/// assert_eq!(top_docs.len(), 2);
/// }
@@ -241,7 +241,8 @@ mod test {
{
let term = get_json_path_term("attributes.aa:japan")?;
let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
}
@@ -252,7 +253,8 @@ mod test {
let term = get_json_path_term("attributes.a:japon")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
}
@@ -262,7 +264,8 @@ mod test {
let term = get_json_path_term("attributes.a:jap")?;
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
@@ -292,7 +295,8 @@ mod test {
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
@@ -303,7 +307,8 @@ mod test {
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 0, "Expected no document");
}
@@ -311,7 +316,8 @@ mod test {
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
let top_docs =
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);

View File

@@ -267,7 +267,7 @@ mod tests {
.with_boost_factor(1.0)
.with_stop_words(vec!["old".to_string()])
.with_document(DocAddress::new(0, 0));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort_unstable();
@@ -283,7 +283,7 @@ mod tests {
.with_max_word_length(5)
.with_boost_factor(1.0)
.with_document(DocAddress::new(0, 4));
let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
doc_ids.sort_unstable();

View File

@@ -266,8 +266,9 @@ mod tests {
use super::RangeQuery;
use crate::collector::{Count, TopDocs};
use crate::indexer::NoMergePolicy;
use crate::query::range_query::fast_field_range_doc_set::RangeDocSet;
use crate::query::range_query::range_query::InvertedIndexRangeQuery;
use crate::query::QueryParser;
use crate::query::{AllScorer, ConstScorer, EmptyScorer, EnableScoring, Query, QueryParser};
use crate::schema::{
Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
};
@@ -495,7 +496,7 @@ mod tests {
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title]);
let query = query_parser.parse_query("hemoglobin AND year:[1970 TO 1990]")?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
assert_eq!(top_docs.len(), 1);
Ok(())
}
@@ -549,7 +550,7 @@ mod tests {
let get_num_hits = |query| {
let (_top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(10), Count))
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap();
count
};
@@ -660,4 +661,46 @@ mod tests {
0
);
}
#[test]
fn test_range_query_simplified() {
// This test checks that if the targeted column values are entirely
// within the range, and the column is full, we end up with a AllScorer.
let mut schema_builder = Schema::builder();
let u64_field = schema_builder.add_u64_field("u64_field", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer.add_document(doc!(u64_field=> 2u64)).unwrap();
index_writer.add_document(doc!(u64_field=> 4u64)).unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let make_term = |value: u64| Term::from_field_u64(u64_field, value);
let make_scorer = move |lower_bound: Bound<u64>, upper_bound: Bound<u64>| {
let lower_bound_term = lower_bound.map(make_term);
let upper_bound_term = upper_bound.map(make_term);
let range_query = RangeQuery::new(lower_bound_term, upper_bound_term);
let range_weight = range_query
.weight(EnableScoring::disabled_from_schema(&schema))
.unwrap();
let range_scorer = range_weight
.scorer(&searcher.segment_readers()[0], 1.0f32)
.unwrap();
range_scorer
};
let range_scorer = make_scorer(Bound::Included(1), Bound::Included(4));
assert!(range_scorer.is::<AllScorer>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Included(2));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(3), Bound::Included(10));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(10), Bound::Included(12));
assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Included(1));
assert!(range_scorer.is::<EmptyScorer>());
let range_scorer = make_scorer(Bound::Included(0), Bound::Excluded(2));
assert!(range_scorer.is::<EmptyScorer>());
}
}

Some files were not shown because too many files have changed in this diff Show More