Compare commits: low_card_o...bucket_id_ (25 commits)
| SHA1 |
|---|
| 3ddca31292 |
| 87fe3a311f |
| 71dc08424c |
| 15913446b8 |
| 78bd3826dc |
| 1b56487307 |
| 030554d544 |
| c852bac532 |
| 2ce4da8b66 |
| 0dd6a958f8 |
| 254314a4a3 |
| b2f99c6217 |
| 76de5bab6f |
| b7eb31162b |
| 63c66005db |
| 7d513a44c5 |
| ca87fcd454 |
| 08a92675dc |
| f7f4b354d6 |
| 25d44fcec8 |
| 842fe9295f |
| f88b7200b2 |
| 8725594d47 |
| 43a784671a |
| c363bbd23d |
@@ -78,7 +78,7 @@ This will slightly increase space and access time. [#2439](https://github.com/qu
 - **Store DateTime as nanoseconds in doc store** DateTime in the doc store was truncated to microseconds previously. This removes this truncation, while still keeping backwards compatibility. [#2486](https://github.com/quickwit-oss/tantivy/pull/2486)(@PSeitz)
-- **Performace/Memory**
+- **Performance/Memory**
   - lift clauses in LogicalAst for optimized ast during execution [#2449](https://github.com/quickwit-oss/tantivy/pull/2449)(@PSeitz)
   - Use Vec instead of BTreeMap to back OwnedValue object [#2364](https://github.com/quickwit-oss/tantivy/pull/2364)(@fulmicoton)
   - Replace TantivyDocument with CompactDoc. CompactDoc is much smaller and provides similar performance. [#2402](https://github.com/quickwit-oss/tantivy/pull/2402)(@PSeitz)
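The nanosecond changelog entry above is easiest to see as a round trip. A minimal sketch, assuming the `DateTime::from_timestamp_nanos`/`into_timestamp_nanos` accessors available in recent tantivy releases:

```rust
use tantivy::DateTime;

fn main() {
    // 2022-06-22T12:53:50.530000123Z — sub-microsecond precision.
    let nanos: i64 = 1_655_902_430_530_000_123;
    let dt = DateTime::from_timestamp_nanos(nanos);
    // Before #2486 the doc store truncated stored dates to microseconds
    // (…530_000_000 here); after it, the full nanosecond value survives.
    assert_eq!(dt.into_timestamp_nanos(), nanos);
}
```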
@@ -1,6 +1,6 @@
 [package]
 name = "tantivy"
-version = "0.25.0"
+version = "0.26.0"
 authors = ["Paul Masurel <paul.masurel@gmail.com>"]
 license = "MIT"
 categories = ["database-implementations", "data-structures"]
@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
 - [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
 - [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
 - [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
+- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
 - and [more](https://github.com/search?q=tantivy)!

 ### On average, how much faster is Tantivy compared to Lucene?
TODO.txt
@@ -10,7 +10,7 @@ rename FastFieldReaders::open to load
 remove fast field reader

 find a way to unify the two DateTime.
-readd type check in the filter wrapper
+re-add type check in the filter wrapper

 add unit test on columnar list columns.
@@ -1,5 +1,6 @@
 use binggan::plugins::PeakMemAllocPlugin;
 use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
+use rand::distributions::WeightedIndex;
 use rand::prelude::SliceRandom;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
@@ -53,25 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
     register!(group, stats_f64);
     register!(group, extendedstats_f64);
     register!(group, percentiles_f64);
-    register!(group, terms_few);
-    register!(group, terms_many);
+    register!(group, terms_7);
+    register!(group, terms_all_unique);
+    register!(group, terms_150_000);
     register!(group, terms_many_top_1000);
     register!(group, terms_many_order_by_term);
     register!(group, terms_many_with_top_hits);
+    register!(group, terms_all_unique_with_avg_sub_agg);
     register!(group, terms_many_with_avg_sub_agg);
+    register!(group, terms_status_with_avg_sub_agg);
+    register!(group, terms_status_with_histogram);
+    register!(group, terms_zipf_1000);
+    register!(group, terms_zipf_1000_with_histogram);
+    register!(group, terms_zipf_1000_with_avg_sub_agg);

     register!(group, terms_many_json_mixed_type_with_avg_sub_agg);

     register!(group, cardinality_agg);
-    register!(group, terms_few_with_cardinality_agg);
+    register!(group, terms_status_with_cardinality_agg);

     register!(group, range_agg);
     register!(group, range_agg_with_avg_sub_agg);
-    register!(group, range_agg_with_term_agg_few);
+    register!(group, range_agg_with_term_agg_status);
     register!(group, range_agg_with_term_agg_many);
     register!(group, histogram);
     register!(group, histogram_hard_bounds);
     register!(group, histogram_with_avg_sub_agg);
-    register!(group, histogram_with_term_agg_few);
+    register!(group, histogram_with_term_agg_status);
     register!(group, avg_and_range_with_avg_sub_agg);

     // Filter aggregation benchmarks
@@ -130,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
 }
 fn percentiles_f64(index: &Index) {
     let agg_req = json!({
         "mypercentiles": {
             "percentiles": {
                 "field": "score_f64",
                 "percents": [ 95, 99, 99.9 ]
             }
         }
     });
     execute_agg(index, agg_req);
 }
@@ -150,10 +159,10 @@ fn cardinality_agg(index: &Index) {
     });
     execute_agg(index, agg_req);
 }
-fn terms_few_with_cardinality_agg(index: &Index) {
+fn terms_status_with_cardinality_agg(index: &Index) {
     let agg_req = json!({
         "my_texts": {
-            "terms": { "field": "text_few_terms" },
+            "terms": { "field": "text_few_terms_status" },
             "aggs": {
                 "cardinality": {
@@ -166,13 +175,20 @@ fn terms_few_with_cardinality_agg(index: &Index) {
     execute_agg(index, agg_req);
 }

-fn terms_few(index: &Index) {
+fn terms_7(index: &Index) {
     let agg_req = json!({
-        "my_texts": { "terms": { "field": "text_few_terms" } },
+        "my_texts": { "terms": { "field": "text_few_terms_status" } },
     });
     execute_agg(index, agg_req);
 }
-fn terms_many(index: &Index) {
+fn terms_all_unique(index: &Index) {
     let agg_req = json!({
+        "my_texts": { "terms": { "field": "text_all_unique_terms" } },
+    });
+    execute_agg(index, agg_req);
+}
+
+fn terms_150_000(index: &Index) {
+    let agg_req = json!({
         "my_texts": { "terms": { "field": "text_many_terms" } },
     });
@@ -220,6 +236,72 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
     });
     execute_agg(index, agg_req);
 }
+fn terms_all_unique_with_avg_sub_agg(index: &Index) {
+    let agg_req = json!({
+        "my_texts": {
+            "terms": { "field": "text_all_unique_terms" },
+            "aggs": {
+                "average_f64": { "avg": { "field": "score_f64" } }
+            }
+        },
+    });
+    execute_agg(index, agg_req);
+}
+fn terms_status_with_histogram(index: &Index) {
+    let agg_req = json!({
+        "my_texts": {
+            "terms": { "field": "text_few_terms_status" },
+            "aggs": {
+                "histo": {"histogram": { "field": "score_f64", "interval": 10 }}
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn terms_zipf_1000_with_histogram(index: &Index) {
+    let agg_req = json!({
+        "my_texts": {
+            "terms": { "field": "text_1000_terms_zipf" },
+            "aggs": {
+                "histo": {"histogram": { "field": "score_f64", "interval": 10 }}
+            }
+        }
+    });
+    execute_agg(index, agg_req);
+}
+
+fn terms_status_with_avg_sub_agg(index: &Index) {
+    let agg_req = json!({
+        "my_texts": {
+            "terms": { "field": "text_few_terms_status" },
+            "aggs": {
+                "average_f64": { "avg": { "field": "score_f64" } }
+            }
+        },
+    });
+    execute_agg(index, agg_req);
+}
+
+fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
+    let agg_req = json!({
+        "my_texts": {
+            "terms": { "field": "text_1000_terms_zipf" },
+            "aggs": {
+                "average_f64": { "avg": { "field": "score_f64" } }
+            }
+        },
+    });
+    execute_agg(index, agg_req);
+}
+
+fn terms_zipf_1000(index: &Index) {
+    let agg_req = json!({
+        "my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
+    });
+    execute_agg(index, agg_req);
+}
+
 fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
     let agg_req = json!({
         "my_texts": {
@@ -275,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
     execute_agg(index, agg_req);
 }

-fn range_agg_with_term_agg_few(index: &Index) {
+fn range_agg_with_term_agg_status(index: &Index) {
     let agg_req = json!({
         "rangef64": {
             "range": {
@@ -290,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
             ]
         },
         "aggs": {
-            "my_texts": { "terms": { "field": "text_few_terms" } },
+            "my_texts": { "terms": { "field": "text_few_terms_status" } },
         }
     },
 });
@@ -346,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
     });
     execute_agg(index, agg_req);
 }
-fn histogram_with_term_agg_few(index: &Index) {
+fn histogram_with_term_agg_status(index: &Index) {
     let agg_req = json!({
         "rangef64": {
             "histogram": { "field": "score_f64", "interval": 10 },
             "aggs": {
-                "my_texts": { "terms": { "field": "text_few_terms" } }
+                "my_texts": { "terms": { "field": "text_few_terms_status" } }
             }
         }
     });
@@ -396,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
 }

 fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
+    // Flag to use existing index
+    let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
+    if reuse_index && std::path::Path::new("agg_bench").exists() {
+        return Index::open_in_dir("agg_bench");
+    }
+    // create dir
+    std::fs::create_dir_all("agg_bench")?;
     let mut schema_builder = Schema::builder();
     let text_fieldtype = tantivy::schema::TextOptions::default()
         .set_indexing_options(
@@ -404,20 +493,47 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
         .set_stored();
     let text_field = schema_builder.add_text_field("text", text_fieldtype);
     let json_field = schema_builder.add_json_field("json", FAST);
+    let text_field_all_unique_terms =
+        schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
     let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
     let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
+    let text_field_few_terms_status =
+        schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
+    let text_field_1000_terms_zipf =
+        schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
     let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
     let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
     let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
     let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
-    let index = Index::create_from_tempdir(schema_builder.build())?;
     let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
+    // use tmp dir
+    let index = if reuse_index {
+        Index::create_in_dir("agg_bench", schema_builder.build())?
+    } else {
+        Index::create_from_tempdir(schema_builder.build())?
+    };
+    // Approximate log proportions
+    let status_field_data = [
+        ("INFO", 8000),
+        ("ERROR", 300),
+        ("WARN", 1200),
+        ("DEBUG", 500),
+        ("OK", 500),
+        ("CRITICAL", 20),
+        ("EMERGENCY", 1),
+    ];
+    let log_level_distribution =
+        WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();

     let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();

     let many_terms_data = (0..150_000)
         .map(|num| format!("author{num}"))
         .collect::<Vec<_>>();

+    // Prepare 1000 unique terms sampled using a Zipf distribution.
+    // Exponent ~1.1 approximates top-20 terms covering around ~20%.
+    let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
+    let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
+
     {
         let mut rng = StdRng::from_seed([1u8; 32]);
         let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -427,15 +543,25 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
             index_writer.add_document(doc!())?;
         }
         if cardinality == Cardinality::Multivalued {
+            let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
+            let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
+            let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
+            let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
+            let term_1000_a = &terms_1000[idx_a];
+            let term_1000_b = &terms_1000[idx_b];
             index_writer.add_document(doc!(
                 json_field => json!({"mixed_type": 10.0}),
                 json_field => json!({"mixed_type": 10.0}),
                 text_field => "cool",
                 text_field => "cool",
                 text_field_all_unique_terms => "cool",
                 text_field_all_unique_terms => "coolo",
                 text_field_many_terms => "cool",
                 text_field_many_terms => "cool",
                 text_field_few_terms => "cool",
                 text_field_few_terms => "cool",
+                text_field_few_terms_status => log_level_sample_a,
+                text_field_few_terms_status => log_level_sample_b,
+                text_field_1000_terms_zipf => term_1000_a.as_str(),
+                text_field_1000_terms_zipf => term_1000_b.as_str(),
                 score_field => 1u64,
                 score_field => 1u64,
                 score_field_f64 => lg_norm.sample(&mut rng),
@@ -460,8 +586,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
             index_writer.add_document(doc!(
                 text_field => "cool",
                 json_field => json,
                 text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
                 text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
                 text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
+                text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
+                text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
                 score_field => val as u64,
                 score_field_f64 => lg_norm.sample(&mut rng),
                 score_field_i64 => val as i64,
@@ -513,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
             "avg_score": { "avg": { "field": "score" } },
             "stats_score": { "stats": { "field": "score_f64" } },
             "terms_text": {
-                "terms": { "field": "text_few_terms" }
+                "terms": { "field": "text_few_terms_status" }
             }
         }
     }
@@ -529,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
             "avg_score": { "avg": { "field": "score" } },
             "stats_score": { "stats": { "field": "score_f64" } },
             "terms_text": {
-                "terms": { "field": "text_few_terms" }
+                "terms": { "field": "text_few_terms_status" }
             }
         }
     }
@@ -16,14 +16,15 @@
 // - This bench isolates boolean iteration speed and intersection/union cost.
 // - Use `cargo bench --bench boolean_conjunction` to run.

-use binggan::{black_box, BenchRunner};
+use binggan::{black_box, BenchGroup, BenchRunner};
 use rand::prelude::*;
 use rand::rngs::StdRng;
 use rand::SeedableRng;
-use tantivy::collector::{Count, TopDocs};
-use tantivy::query::QueryParser;
-use tantivy::schema::{Schema, TEXT};
-use tantivy::{doc, Index, ReloadPolicy, Searcher};
+use tantivy::collector::sort_key::SortByStaticFastValue;
+use tantivy::collector::{Collector, Count, TopDocs};
+use tantivy::query::{Query, QueryParser};
+use tantivy::schema::{Schema, FAST, TEXT};
+use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};

 #[derive(Clone)]
 struct BenchIndex {
@@ -33,23 +34,6 @@ struct BenchIndex {
     query_parser: QueryParser,
 }

-impl BenchIndex {
-    #[inline(always)]
-    fn count_query(&self, query_str: &str) -> usize {
-        let query = self.query_parser.parse_query(query_str).unwrap();
-        self.searcher.search(&query, &Count).unwrap()
-    }
-
-    #[inline(always)]
-    fn topk_len(&self, query_str: &str, k: usize) -> usize {
-        let query = self.query_parser.parse_query(query_str).unwrap();
-        self.searcher
-            .search(&query, &TopDocs::with_limit(k))
-            .unwrap()
-            .len()
-    }
-}
-
 /// Build a single index containing both fields (title, body) and
 /// return two BenchIndex views:
 /// - single_field: QueryParser defaults to only "body"
@@ -59,6 +43,8 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
     let mut schema_builder = Schema::builder();
     let f_title = schema_builder.add_text_field("title", TEXT);
     let f_body = schema_builder.add_text_field("body", TEXT);
+    let f_score = schema_builder.add_u64_field("score", FAST);
+    let f_score2 = schema_builder.add_u64_field("score2", FAST);
     let schema = schema_builder.build();
     let index = Index::create_in_ram(schema.clone());
@@ -67,11 +53,13 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench

     // Populate: spread each present token 90/10 to body/title
     {
-        let mut writer = index.writer(500_000_000).unwrap();
+        let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
         for _ in 0..num_docs {
             let has_a = rng.gen_bool(p_a as f64);
             let has_b = rng.gen_bool(p_b as f64);
             let has_c = rng.gen_bool(p_c as f64);
+            let score = rng.gen_range(0u64..100u64);
+            let score2 = rng.gen_range(0u64..100_000u64);
             let mut title_tokens: Vec<&str> = Vec::new();
             let mut body_tokens: Vec<&str> = Vec::new();
             if has_a {
@@ -101,7 +89,9 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
             writer
                 .add_document(doc!(
                     f_title=>title_tokens.join(" "),
-                    f_body=>body_tokens.join(" ")
+                    f_body=>body_tokens.join(" "),
+                    f_score=>score,
+                    f_score2=>score2,
                 ))
                 .unwrap();
         }
@@ -153,72 +143,76 @@ fn main() {
         ),
     ];

+    let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
+
     let mut runner = BenchRunner::new();
     for (label, n, pa, pb, pc) in scenarios {
         let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);

-        // Single-field group: default field is body only
-        {
+        for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
+        {
             let mut group = runner.new_group();
-            group.set_name(format!("single_field — {}", label));
-            group.register_with_input("+a_+b_count", &single_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("+a +b"))
-            });
-            group.register_with_input("+a_+b_+c_count", &single_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("+a +b +c"))
-            });
-            group.register_with_input("+a_+b_top10", &single_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("+a +b", 10))
-            });
-            group.register_with_input("+a_+b_+c_top10", &single_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("+a +b +c", 10))
-            });
-            // OR queries
-            group.register_with_input("a_OR_b_count", &single_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("a OR b"))
-            });
-            group.register_with_input("a_OR_b_OR_c_count", &single_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("a OR b OR c"))
-            });
-            group.register_with_input("a_OR_b_top10", &single_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("a OR b", 10))
-            });
-            group.register_with_input("a_OR_b_OR_c_top10", &single_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("a OR b OR c", 10))
-            });
-            group.run();
-        }
-
-        // Multi-field group: default fields are [title, body]
-        {
-            let mut group = runner.new_group();
-            group.set_name(format!("multi_field — {}", label));
-            group.register_with_input("+a_+b_count", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("+a +b"))
-            });
-            group.register_with_input("+a_+b_+c_count", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("+a +b +c"))
-            });
-            group.register_with_input("+a_+b_top10", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("+a +b", 10))
-            });
-            group.register_with_input("+a_+b_+c_top10", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("+a +b +c", 10))
-            });
-            // OR queries
-            group.register_with_input("a_OR_b_count", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("a OR b"))
-            });
-            group.register_with_input("a_OR_b_OR_c_count", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.count_query("a OR b OR c"))
-            });
-            group.register_with_input("a_OR_b_top10", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("a OR b", 10))
-            });
-            group.register_with_input("a_OR_b_OR_c_top10", &multi_view, |benv: &BenchIndex| {
-                black_box(benv.topk_len("a OR b OR c", 10))
-            });
+            group.set_name(format!("{} — {}", view_name, label));
+            for query_str in queries {
+                add_bench_task(&mut group, &bench_index, query_str, Count, "count");
+                add_bench_task(
+                    &mut group,
+                    &bench_index,
+                    query_str,
+                    TopDocs::with_limit(10).order_by_score(),
+                    "top10",
+                );
+                add_bench_task(
+                    &mut group,
+                    &bench_index,
+                    query_str,
+                    TopDocs::with_limit(10).order_by_fast_field::<u64>("score", Order::Asc),
+                    "top10_by_ff",
+                );
+                add_bench_task(
+                    &mut group,
+                    &bench_index,
+                    query_str,
+                    TopDocs::with_limit(10).order_by((
+                        SortByStaticFastValue::<u64>::for_field("score"),
+                        SortByStaticFastValue::<u64>::for_field("score2"),
+                    )),
+                    "top10_by_2ff",
+                );
+            }
             group.run();
         }
     }
 }

+fn add_bench_task<C: Collector + 'static>(
+    bench_group: &mut BenchGroup,
+    bench_index: &BenchIndex,
+    query_str: &str,
+    collector: C,
+    collector_name: &str,
+) {
+    let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
+    let query = bench_index.query_parser.parse_query(query_str).unwrap();
+    let search_task = SearchTask {
+        searcher: bench_index.searcher.clone(),
+        collector,
+        query,
+    };
+    bench_group.register(task_name, move |_| black_box(search_task.run()));
+}
+
+struct SearchTask<C: Collector> {
+    searcher: Searcher,
+    collector: C,
+    query: Box<dyn Query>,
+}
+
+impl<C: Collector> SearchTask<C> {
+    #[inline(never)]
+    pub fn run(&self) -> usize {
+        self.searcher.search(&self.query, &self.collector).unwrap();
+        1
+    }
+}
@@ -258,7 +258,7 @@ mod test {
             bitpacker.write(val, num_bits, &mut data).unwrap();
         }
         bitpacker.close(&mut data).unwrap();
-        assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8);
+        assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
         let bitunpacker = BitUnpacker::new(num_bits);
         (bitunpacker, vals, data)
     }
@@ -304,7 +304,7 @@ mod test {
             bitpacker.write(val, num_bits, &mut buffer).unwrap();
         }
         bitpacker.flush(&mut buffer).unwrap();
-        assert_eq!(buffer.len(), (vals.len() * num_bits as usize + 7) / 8);
+        assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
         let bitunpacker = BitUnpacker::new(num_bits);
         let max_val = if num_bits == 64 {
             u64::MAX
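Both hunks swap a hand-rolled ceiling division for `div_ceil`. The two forms agree wherever the addition does not overflow; a quick standalone check:

```rust
fn main() {
    for n in [0usize, 1, 7, 8, 9, 63, 64, 65] {
        assert_eq!((n + 7) / 8, n.div_ceil(8));
    }
    // div_ceil also sidesteps the overflow in `n + 7` for n near usize::MAX:
    assert_eq!(usize::MAX.div_ceil(8), usize::MAX / 8 + 1);
}
```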
@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
 #[inline]
 unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
     const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
-    op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
+    unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
 }

 pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
     ]);
     const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
     for _ in 0..num_words {
-        let word = load_unaligned(input);
-        let word = u32_to_i32_avx2(word);
-        let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
-        let added_len = keeper_bitset.count_ones();
-        let filtered_doc_ids = compact(ids, keeper_bitset);
-        store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
-        output_tail = output_tail.offset(added_len as isize);
-        ids = op_add(ids, SHIFT);
-        input = input.offset(1);
+        unsafe {
+            let word = load_unaligned(input);
+            let word = u32_to_i32_avx2(word);
+            let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
+            let added_len = keeper_bitset.count_ones();
+            let filtered_doc_ids = compact(ids, keeper_bitset);
+            store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
+            output_tail = output_tail.offset(added_len as isize);
+            ids = op_add(ids, SHIFT);
+            input = input.offset(1);
+        }
     }
-    output_tail.offset_from(output) as usize
+    unsafe { output_tail.offset_from(output) as usize }
 }

 #[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
     let too_low = op_greater(*range.start(), val);
     let too_high = op_greater(val, *range.end());
     let inside = op_or(too_low, too_high);
-    255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
-        as u8
+    255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
 }

 union U8x32 {
@@ -73,7 +73,7 @@ The crate introduces the following concepts.
 `Columnar` is an equivalent of a dataframe.
 It maps `column_key` to `Column`.

-A `Column<T>` asssociates a `RowId` (u32) to any
+A `Column<T>` associates a `RowId` (u32) to any
 number of values.

 This is made possible by wrapping a `ColumnIndex` and a `ColumnValue` object.
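A hedged sketch of the `ColumnIndex` + `ColumnValue` composition that paragraph describes — the trait and struct bodies below are illustrative stand-ins, not the crate's actual definitions:

```rust
/// Maps a document (RowId) to the range of value rows that belong to it.
trait ColumnIndexSketch {
    fn value_rows(&self, row_id: u32) -> std::ops::Range<u32>;
}

/// Dense storage of the values themselves, addressed by value row.
trait ColumnValuesSketch {
    type Value;
    fn get(&self, value_row: u32) -> Self::Value;
}

struct ColumnSketch<I, V> {
    index: I,
    values: V,
}

impl<I: ColumnIndexSketch, V: ColumnValuesSketch> ColumnSketch<I, V> {
    /// Any number of values per RowId: empty range = missing,
    /// len 1 = single value, len > 1 = multi-valued.
    fn values_for_row(&self, row_id: u32) -> impl Iterator<Item = V::Value> + '_ {
        self.index.value_rows(row_id).map(|r| self.values.get(r))
    }
}

struct DenseIndex; // illustrative: every row has exactly one value
impl ColumnIndexSketch for DenseIndex {
    fn value_rows(&self, row_id: u32) -> std::ops::Range<u32> {
        row_id..row_id + 1
    }
}

struct VecValues(Vec<u64>);
impl ColumnValuesSketch for VecValues {
    type Value = u64;
    fn get(&self, value_row: u32) -> u64 {
        self.0[value_row as usize]
    }
}

fn main() {
    let col = ColumnSketch { index: DenseIndex, values: VecValues(vec![10, 20, 30]) };
    assert_eq!(col.values_for_row(1).collect::<Vec<_>>(), vec![20]);
}
```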
@@ -89,13 +89,6 @@ fn main() {
         black_box(sum);
     });

-    group.register("first_block_fetch", |column| {
-        let mut block: Vec<Option<u64>> = vec![None; 64];
-        let fetch_docids = (0..64).collect::<Vec<_>>();
-        column.first_vals(&fetch_docids, &mut block);
-        black_box(block[0]);
-    });
-
     group.register("first_block_single_calls", |column| {
         let mut block: Vec<Option<u64>> = vec![None; 64];
         let fetch_docids = (0..64).collect::<Vec<_>>();
@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
         }
     }
     #[inline]
-    pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
+    pub fn fetch_block_with_missing(
+        &mut self,
+        docs: &[u32],
+        accessor: &Column<T>,
+        missing: Option<T>,
+    ) {
         self.fetch_block(docs, accessor);
         // no missing values
         if accessor.index.get_cardinality().is_full() {
             return;
         }
+        let Some(missing) = missing else {
+            return;
+        };

         // We can compare docid_cache length with docs to find missing docs
         // For multi value columns we can't rely on the length and always need to scan
@@ -131,6 +131,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
         self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
     }

+    /// Get an iterator over the values for the provided docid.
+    #[inline]
     pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
         self.index
             .value_row_ids(doc_id)
@@ -158,15 +160,6 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
             .select_batch_in_place(selected_docid_range.start, doc_ids);
     }

-    /// Fills the output vector with the (possibly multiple values that are associated_with
-    /// `row_id`.
-    ///
-    /// This method clears the `output` vector.
-    pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
-        output.clear();
-        output.extend(self.values_for_doc(row_id));
-    }
-
     pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
         Arc::new(FirstValueWithDefault {
             column: self,
@@ -1,7 +1,7 @@
 use std::fmt::Debug;
 use std::net::Ipv6Addr;

-/// Montonic maps a value to u128 value space
+/// Monotonic maps a value to u128 value space
 /// Monotonic mapping enables `PartialOrd` on u128 space without conversion to original space.
 pub trait MonotonicallyMappableToU128: 'static + PartialOrd + Copy + Debug + Send + Sync {
     /// Converts a value to u128.
@@ -8,7 +8,7 @@ use crate::column_values::ColumnValues;
 const MID_POINT: u64 = (1u64 << 32) - 1u64;

 /// `Line` describes a line function `y: ax + b` using integer
-/// arithmetics.
+/// arithmetic.
 ///
 /// The slope is in fact a decimal split into a 32 bit integer value,
 /// and a 32-bit decimal value.
@@ -94,7 +94,7 @@ impl Line {
     // `(i, ys[])`.
     //
     // The best intercept therefore has the form
-    // `y[i] - line.eval(i)` (using wrapping arithmetics).
+    // `y[i] - line.eval(i)` (using wrapping arithmetic).
     // In other words, the best intercept is one of the `y - Line::eval(ys[i])`
    // and our task is just to pick the one that minimizes our error.
     //
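The 32.32 fixed-point slope that doc comment describes can be sketched as follows; this is an illustrative reimplementation under those assumptions, not the crate's actual `Line::eval`:

```rust
/// Illustrative 32.32 fixed-point line: integer part of the slope in the
/// high 32 bits, fractional part in the low 32 bits.
struct LineSketch {
    slope_fixed: u64, // `a` as 32.32 fixed point
    intercept: u64,   // `b`
}

impl LineSketch {
    fn eval(&self, x: u64) -> u64 {
        // (a * x) >> 32 recovers the integer result of the decimal slope
        // times x; the wrapping add mirrors the wrapping arithmetic the
        // intercept-selection comment above relies on.
        let ax = (self.slope_fixed as u128 * x as u128) >> 32;
        (ax as u64).wrapping_add(self.intercept)
    }
}

fn main() {
    // slope 2.5 => integer part 2, fractional part 0.5 * 2^32
    let line = LineSketch { slope_fixed: (2u64 << 32) + (1u64 << 31), intercept: 7 };
    assert_eq!(line.eval(4), 17); // 2.5 * 4 + 7
}
```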
@@ -52,7 +52,7 @@ pub trait ColumnCodecEstimator<T = u64>: 'static {
     ) -> io::Result<()>;
 }

-/// A column codec describes a colunm serialization format.
+/// A column codec describes a column serialization format.
 pub trait ColumnCodec<T: PartialOrd = u64> {
     /// Specialized `ColumnValues` type.
     type ColumnValues: ColumnValues<T> + 'static;
@@ -181,6 +181,14 @@ pub struct BitSet {
     len: u64,
     max_value: u32,
 }
+impl std::fmt::Debug for BitSet {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("BitSet")
+            .field("len", &self.len)
+            .field("max_value", &self.max_value)
+            .finish()
+    }
+}

 fn num_buckets(max_val: u32) -> u32 {
     max_val.div_ceil(64u32)
@@ -28,7 +28,9 @@ impl BinarySerializable for VIntU128 {
         writer.write_all(&buffer)
     }

+    #[allow(clippy::unbuffered_bytes)]
     fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
+        #[allow(clippy::unbuffered_bytes)]
         let mut bytes = reader.bytes();
         let mut result = 0u128;
         let mut shift = 0u64;
@@ -195,7 +197,9 @@ impl BinarySerializable for VInt {
         writer.write_all(&buffer[0..num_bytes])
     }

+    #[allow(clippy::unbuffered_bytes)]
     fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
+        #[allow(clippy::unbuffered_bytes)]
         let mut bytes = reader.bytes();
         let mut result = 0u64;
         let mut shift = 0u64;
@@ -208,7 +208,7 @@ fn main() -> tantivy::Result<()> {
     // is the role of the `TopDocs` collector.

     // We can now perform our query.
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;

     // The actual documents still need to be
     // retrieved from Tantivy's store.
@@ -226,7 +226,7 @@ fn main() -> tantivy::Result<()> {
     let query = query_parser.parse_query("title:sea^20 body:whale^70")?;

     let (_score, doc_address) = searcher
-        .search(&query, &TopDocs::with_limit(1))?
+        .search(&query, &TopDocs::with_limit(1).order_by_score())?
         .into_iter()
         .next()
         .unwrap();
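The example updates in this and the following hunks all follow one pattern: on this branch `TopDocs::with_limit(n)` no longer implies relevance order, so `.order_by_score()` is spelled out. A minimal end-to-end sketch of the explicit-ordering call, assuming the branch's `order_by_score()` builder shown in these diffs:

```rust
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index};

fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let body = schema_builder.add_text_field("body", TEXT);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
    writer.add_document(doc!(body => "the old man and the sea"))?;
    writer.commit()?;
    let searcher = index.reader()?.searcher();
    let query = QueryParser::for_index(&index, vec![body]).parse_query("sea")?;
    // The sort key is now explicit, which also opens the door to the
    // fast-field orderings used in the boolean benchmark above.
    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
    assert_eq!(top_docs.len(), 1);
    Ok(())
}
```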
@@ -100,7 +100,7 @@ fn main() -> tantivy::Result<()> {
     // here we want to get a hit on the 'ken' in Frankenstein
     let query = query_parser.parse_query("ken")?;

-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;

     for (_, doc_address) in top_docs {
         let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
@@ -50,14 +50,14 @@ fn main() -> tantivy::Result<()> {
     {
         // Simple exact search on the date
         let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
     }
     {
         // Range query on the date field
         let query = query_parser
             .parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
         for (_score, doc_address) in count_docs {
             let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;
@@ -28,7 +28,7 @@ fn extract_doc_given_isbn(
     // The second argument is here to tell we don't care about decoding positions,
     // or term frequencies.
     let term_query = TermQuery::new(isbn_term.clone(), IndexRecordOption::Basic);
-    let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1))?;
+    let top_docs = searcher.search(&term_query, &TopDocs::with_limit(1).order_by_score())?;

     if let Some((_score, doc_address)) = top_docs.first() {
         let doc = searcher.doc(*doc_address)?;
@@ -145,7 +145,7 @@ fn main() -> tantivy::Result<()> {
     let query = FuzzyTermQuery::new(term, 2, true);

     let (top_docs, count) = searcher
-        .search(&query, &(TopDocs::with_limit(5), Count))
+        .search(&query, &(TopDocs::with_limit(5).order_by_score(), Count))
         .unwrap();
     assert_eq!(count, 3);
     assert_eq!(top_docs.len(), 3);
@@ -69,25 +69,25 @@ fn main() -> tantivy::Result<()> {
     {
         // Inclusive range queries
         let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(5).order_by_score())?;
         assert_eq!(count_docs.len(), 1);
     }
     {
         // Exclusive range queries
         let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 0);
     }
     {
         // Find docs with IP addresses smaller equal 192.168.1.100
         let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
         // Find docs with IP addresses smaller than 192.168.1.100
         let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
@@ -59,12 +59,12 @@ fn main() -> tantivy::Result<()> {
     let query_parser = QueryParser::for_index(&index, vec![event_type, attributes]);
     {
         let query = query_parser.parse_query("target:submit-button")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
         let query = query_parser.parse_query("target:submit")?;
-        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let count_docs = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(count_docs.len(), 2);
     }
     {
@@ -74,33 +74,33 @@ fn main() -> tantivy::Result<()> {
     }
     {
         let query = query_parser.parse_query("click AND cart.product_id:133")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 1);
     }
     {
         // The sub-fields in the json field marked as default field still need to be explicitly
         // addressed
         let query = query_parser.parse_query("click AND 133")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 0);
     }
     {
         // Default json fields are ignored if they collide with the schema
         let query = query_parser.parse_query("event_type:holiday-sale")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 0);
     }
     // # Query via full attribute path
     {
         // This only searches in our schema's `event_type` field
         let query = query_parser.parse_query("event_type:click")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 2);
     }
     {
         // Default json fields can still be accessed by full path
         let query = query_parser.parse_query("attributes.event_type:holiday-sale")?;
-        let hits = searcher.search(&*query, &TopDocs::with_limit(2))?;
+        let hits = searcher.search(&*query, &TopDocs::with_limit(2).order_by_score())?;
         assert_eq!(hits.len(), 1);
     }
     Ok(())
@@ -63,7 +63,7 @@ fn main() -> Result<()> {
     // but not "in the Gulf Stream".
     let query = query_parser.parse_query("\"in the su\"*")?;

-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
     let mut titles = top_docs
         .into_iter()
         .map(|(_score, doc_address)| {
@@ -107,7 +107,8 @@ fn main() -> tantivy::Result<()> {
         IndexRecordOption::Basic,
     );

-    let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
+    let (top_docs, count) =
+        searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;

     assert_eq!(count, 2);

@@ -128,7 +129,8 @@ fn main() -> tantivy::Result<()> {
         IndexRecordOption::Basic,
     );

-    let (_top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count))?;
+    let (_top_docs, count) =
+        searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count))?;

     assert_eq!(count, 0);
@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
     let query_parser = QueryParser::for_index(&index, vec![title, body]);
     let query = query_parser.parse_query("sycamore spring")?;

-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;

     let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
@@ -102,7 +102,7 @@ fn main() -> tantivy::Result<()> {
     // stop words are applied on the query as well.
     // The following will be equivalent to `title:frankenstein`
     let query = query_parser.parse_query("title:\"the Frankenstein\"")?;
-    let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;

     for (score, doc_address) in top_docs {
         let retrieved_doc: TantivyDocument = searcher.doc(doc_address)?;
@@ -164,7 +164,7 @@ fn main() -> tantivy::Result<()> {
         move |doc_id: DocId| Reverse(price[doc_id as usize])
     };

-    let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price);
+    let most_expensive_first = TopDocs::with_limit(10).order_by(score_by_price);

     let hits = searcher.search(&query, &most_expensive_first)?;
     assert_eq!(
@@ -758,7 +758,17 @@ fn negate(expr: UserInputAst) -> UserInputAst {
 fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
     alt((
         delimited(char('('), ast, char(')')),
-        map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)),
+        map(
+            terminated(
+                char('*'),
+                peek(alt((
+                    value((), multispace1),
+                    value((), char(')')),
+                    value((), eof),
+                ))),
+            ),
+            |_| UserInputAst::from(UserInputLeaf::All),
+        ),
         map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate),
         literal,
     ))(inp)
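The `terminated(char('*'), peek(...))` combinator above (mirrored in `leaf_infallible` next) only treats `*` as a match-all when it is followed by whitespace, `)` or end of input, leaving inputs like `*A` free to parse as a prefix term. A self-contained sketch of just that rule, assuming nom 7's combinators as used in this parser:

```rust
use nom::branch::alt;
use nom::character::complete::{char, multispace1};
use nom::combinator::{eof, peek, value};
use nom::sequence::terminated;
use nom::IResult;

/// '*' counts as match-all only when whitespace, ')' or eof follows.
fn all_query(inp: &str) -> IResult<&str, char> {
    terminated(
        char('*'),
        peek(alt((
            value((), multispace1),
            value((), char(')')),
            value((), eof),
        ))),
    )(inp)
}

fn main() {
    assert!(all_query("*").is_ok());   // eof follows
    assert!(all_query("* ").is_ok());  // whitespace follows
    assert!(all_query("*A").is_err()); // prefix term, not match-all
}
```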
@@ -779,7 +789,17 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
             ),
         ),
         (
-            value((), char('*')),
+            value(
+                (),
+                terminated(
+                    char('*'),
+                    peek(alt((
+                        value((), multispace1),
+                        value((), char(')')),
+                        value((), eof),
+                    ))),
+                ),
+            ),
             map(nothing, |_| {
                 (Some(UserInputAst::from(UserInputLeaf::All)), Vec::new())
             }),
@@ -1671,6 +1691,21 @@ mod test {
         test_parse_query_to_ast_helper("abc:a b", "(*\"abc\":a *b)");
         test_parse_query_to_ast_helper("abc:\"a b\"", "\"abc\":\"a b\"");
         test_parse_query_to_ast_helper("foo:[1 TO 5]", "\"foo\":[\"1\" TO \"5\"]");
+
+        // Phrase prefixed with *
+        test_parse_query_to_ast_helper("foo:(*A)", "\"foo\":*A");
+        test_parse_query_to_ast_helper("*A", "*A");
+        test_parse_query_to_ast_helper("(*A)", "*A");
+        test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
+        test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
+        test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
+    }
+
+    #[test]
+    fn test_parse_query_all() {
+        test_parse_query_to_ast_helper("*", "*");
+        test_parse_query_to_ast_helper("(*)", "*");
+        test_parse_query_to_ast_helper("(* )", "*");
     }

     #[test]
@@ -16,15 +16,16 @@ use crate::index::SegmentReader;
 /// That way we can use it the same way as if it would come from the fastfield.
 pub(crate) fn get_missing_val_as_u64_lenient(
     column_type: ColumnType,
+    column_max_value: u64,
     missing: &Key,
     field_name: &str,
 ) -> crate::Result<Option<u64>> {
     let missing_val = match missing {
-        Key::Str(_) if column_type == ColumnType::Str => Some(u64::MAX),
+        Key::Str(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
         // Allow fallback to number on text fields
-        Key::F64(_) if column_type == ColumnType::Str => Some(u64::MAX),
-        Key::U64(_) if column_type == ColumnType::Str => Some(u64::MAX),
-        Key::I64(_) if column_type == ColumnType::Str => Some(u64::MAX),
+        Key::F64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
+        Key::U64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
+        Key::I64(_) if column_type == ColumnType::Str => Some(column_max_value + 1),
        Key::F64(val) if column_type.numerical_type().is_some() => {
             f64_to_fastfield_u64(*val, &column_type)
         }
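For string columns the values this function produces are term ordinals, so the missing sentinel moves from `u64::MAX` to one past the column's largest real ordinal. A plain-Rust illustration of why a tight sentinel is convenient — the dense-bucket rationale here is an assumption of this sketch, not stated in the diff:

```rust
// A missing sentinel adjacent to the real term ordinals keeps any
// ordinal-indexed structure dense, instead of having to cope with an
// index at u64::MAX.
fn missing_sentinel(max_term_ord: u64) -> u64 {
    max_term_ord + 1
}

fn main() {
    let max_term_ord = 41; // 42 distinct terms: ordinals 0..=41
    let sentinel = missing_sentinel(max_term_ord);
    let mut counts = vec![0u64; (sentinel + 1) as usize]; // dense: 43 slots
    counts[sentinel as usize] += 1; // one doc with no value
    assert_eq!(counts.len(), 43);
}
```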
@@ -1,4 +1,4 @@
-use columnar::{Column, ColumnType, StrColumn};
+use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
 use common::BitSet;
 use rustc_hash::FxHashSet;
 use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
 };
 use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
 use crate::aggregation::bucket::{
-    FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
-    MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
-    SegmentRangeCollector, SegmentTermCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
+    build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
+    IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
+    SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
     TermsAggregationInternal,
 };
 use crate::aggregation::metric::{
-    AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
-    ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
-    SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
-    SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
+    build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
+    CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
+    MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
+    SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
     TopHitsSegmentCollector,
 };
 use crate::aggregation::segment_agg_result::{
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
     /// Request data for each aggregation type.
     pub per_request: PerRequestAggSegCtx,
     pub context: AggContextParams,
+    pub column_block_accessor: ColumnBlockAccessor<u64>,
 }

 impl AggregationsSegmentCtx {
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
             .as_deref()
             .expect("range_req_data slot is empty (taken)")
     }
-    #[inline]
-    pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
-        self.per_request.filter_req_data[idx]
-            .as_deref()
-            .expect("filter_req_data slot is empty (taken)")
-    }

     // ---------- mutable getters ----------

     #[inline]
-    pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
-        self.per_request.term_req_data[idx]
-            .as_deref_mut()
-            .expect("term_req_data slot is empty (taken)")
+    pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
+        &mut self.per_request.stats_metric_req_data[idx]
     }

     #[inline]
     pub(crate) fn get_cardinality_req_data_mut(
         &mut self,
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
     ) -> &mut CardinalityAggReqData {
         &mut self.per_request.cardinality_req_data[idx]
     }
-    #[inline]
-    pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
-        &mut self.per_request.stats_metric_req_data[idx]
-    }

     #[inline]
     pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
         self.per_request.histogram_req_data[idx]
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {

     // ---------- take / put (terms, histogram, range) ----------

-    /// Move out the boxed Terms request at `idx`, leaving `None`.
-    #[inline]
-    pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
-        self.per_request.term_req_data[idx]
-            .take()
-            .expect("term_req_data slot is empty (taken)")
-    }
-
-    /// Put back a Terms request into an empty slot at `idx`.
-    #[inline]
-    pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
-        debug_assert!(self.per_request.term_req_data[idx].is_none());
-        self.per_request.term_req_data[idx] = Some(value);
-    }
-
     /// Move out the boxed Histogram request at `idx`, leaving `None`.
     #[inline]
     pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {

     /// Convert the aggregation tree into a serializable struct representation.
     /// Each node contains: { name, kind, children }.
+    #[allow(dead_code)]
     pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
         fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
             let mut children: Vec<AggTreeViewNode> =
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
 pub(crate) fn build_segment_agg_collectors_root(
     req: &mut AggregationsSegmentCtx,
 ) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
-    build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
+    build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
 }

 pub(crate) fn build_segment_agg_collectors(
     req: &mut AggregationsSegmentCtx,
     nodes: &[AggRefNode],
 ) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
+    build_segment_agg_collectors_generic(req, nodes)
+}
+
+fn build_segment_agg_collectors_generic(
+    req: &mut AggregationsSegmentCtx,
+    nodes: &[AggRefNode],
+) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
     let mut collectors = Vec::new();
     for node in nodes.iter() {
@@ -373,9 +357,7 @@ pub(crate) fn build_segment_agg_collector(
     node: &AggRefNode,
 ) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
     match node.kind {
-        AggKind::Terms => Ok(Box::new(SegmentTermCollector::from_req_and_validate(
-            req, node,
-        )?)),
+        AggKind::Terms => crate::aggregation::bucket::build_segment_term_collector(req, node),
         AggKind::MissingTerm => {
             let req_data = &mut req.per_request.missing_term_req_data[node.idx_in_req_data];
             if req_data.accessors.is_empty() {
@@ -390,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
             Ok(Box::new(SegmentCardinalityCollector::from_req(
                 req_data.column_type,
                 node.idx_in_req_data,
+                req_data.accessor.clone(),
+                req_data.missing_value_for_accessor,
             )))
         }
         AggKind::StatsKind(stats_type) => {
@@ -400,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
                 | StatsType::Count
                 | StatsType::Max
                 | StatsType::Min
-                | StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
-                    node.idx_in_req_data,
-                ))),
-                StatsType::ExtendedStats(sigma) => {
-                    Ok(Box::new(SegmentExtendedStatsCollector::from_req(
-                        req_data.field_type,
-                        sigma,
-                        node.idx_in_req_data,
-                        req_data.missing,
-                    )))
-                }
-                StatsType::Percentiles => Ok(Box::new(
-                    SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
-                )),
+                | StatsType::Stats => build_segment_stats_collector(req_data),
+                StatsType::ExtendedStats(sigma) => Ok(Box::new(
+                    SegmentExtendedStatsCollector::from_req(req_data, sigma),
+                )),
+                StatsType::Percentiles => {
+                    let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
+                    Ok(Box::new(
+                        SegmentPercentilesCollector::from_req_and_validate(
+                            req_data.field_type,
+                            req_data.missing_u64,
+                            req_data.accessor.clone(),
+                            node.idx_in_req_data,
+                        ),
+                    ))
+                }
             }
         }
         AggKind::TopHits => {
@@ -430,9 +415,7 @@ pub(crate) fn build_segment_agg_collector(
         AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
             req, node,
         )?)),
-        AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
-            req, node,
-        )?)),
+        AggKind::Range => Ok(build_segment_range_collector(req, node)?),
         AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
             req, node,
         )?)),
@@ -495,10 +478,11 @@ pub(crate) fn build_aggregations_data_from_req(
     let mut data = AggregationsSegmentCtx {
         per_request: Default::default(),
         context,
+        column_block_accessor: ColumnBlockAccessor::default(),
     };

     for (name, agg) in aggs.iter() {
-        let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data)?;
+        let nodes = build_nodes(name, agg, reader, segment_ordinal, &mut data, true)?;
         data.per_request.agg_tree.extend(nodes);
     }
     Ok(data)
@@ -510,6 +494,7 @@ fn build_nodes(
     reader: &SegmentReader,
     segment_ordinal: SegmentOrdinal,
     data: &mut AggregationsSegmentCtx,
+    is_top_level: bool,
 ) -> crate::Result<Vec<AggRefNode>> {
     use AggregationVariants::*;
     match &req.agg {
@@ -522,9 +507,9 @@ fn build_nodes(
             let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
                 accessor,
                 field_type,
-                column_block_accessor: Default::default(),
                 name: agg_name.to_string(),
                 req: range_req.clone(),
+                is_top_level,
             });
             let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
             Ok(vec![AggRefNode {
@@ -542,9 +527,7 @@ fn build_nodes(
             let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
                 accessor,
                 field_type,
-                column_block_accessor: Default::default(),
                 name: agg_name.to_string(),
-                sub_aggregation_blueprint: None,
                 req: histo_req.clone(),
                 is_date_histogram: false,
                 bounds: HistogramBounds {
@@ -569,9 +552,7 @@ fn build_nodes(
             let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
                 accessor,
                 field_type,
-                column_block_accessor: Default::default(),
                 name: agg_name.to_string(),
-                sub_aggregation_blueprint: None,
                 req: histo_req,
                 is_date_histogram: true,
                 bounds: HistogramBounds {
@@ -596,6 +577,7 @@ fn build_nodes(
             data,
             &req.sub_aggregation,
             TermsOrCardinalityRequest::Terms(terms_req.clone()),
+            is_top_level,
         ),
         Cardinality(card_req) => build_terms_or_cardinality_nodes(
             agg_name,
@@ -606,6 +588,7 @@ fn build_nodes(
             data,
             &req.sub_aggregation,
             TermsOrCardinalityRequest::Cardinality(card_req.clone()),
+            is_top_level,
         ),
         Average(AverageAggregation { field, missing, .. })
         | Max(MaxAggregation { field, missing, .. })
@@ -649,7 +632,6 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for,
|
||||
missing: *missing,
|
||||
@@ -677,7 +659,6 @@ fn build_nodes(
|
||||
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
|
||||
accessor,
|
||||
field_type,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
collecting_for: StatsType::Percentiles,
|
||||
missing: percentiles_req.missing,
|
||||
@@ -734,7 +715,7 @@ fn build_nodes(
|
||||
// Build the query and evaluator upfront
|
||||
let schema = reader.schema();
|
||||
let tokenizers = &data.context.tokenizers;
|
||||
let query = filter_req.parse_query(&schema, tokenizers)?;
|
||||
let query = filter_req.parse_query(schema, tokenizers)?;
|
||||
let evaluator = crate::aggregation::bucket::DocumentQueryEvaluator::new(
|
||||
query,
|
||||
schema.clone(),
|
||||
@@ -771,7 +752,14 @@ fn build_children(
|
||||
) -> crate::Result<Vec<AggRefNode>> {
|
||||
let mut children = Vec::new();
|
||||
for (name, agg) in aggs.iter() {
|
||||
children.extend(build_nodes(name, agg, reader, segment_ordinal, data)?);
|
||||
children.extend(build_nodes(
|
||||
name,
|
||||
agg,
|
||||
reader,
|
||||
segment_ordinal,
|
||||
data,
|
||||
false,
|
||||
)?);
|
||||
}
|
||||
Ok(children)
|
||||
}
|
||||
@@ -835,6 +823,7 @@ fn build_terms_or_cardinality_nodes(
|
||||
data: &mut AggregationsSegmentCtx,
|
||||
sub_aggs: &Aggregations,
|
||||
req: TermsOrCardinalityRequest,
|
||||
is_top_level: bool,
|
||||
) -> crate::Result<Vec<AggRefNode>> {
|
||||
let mut nodes = Vec::new();
|
||||
|
||||
@@ -886,12 +875,12 @@ fn build_terms_or_cardinality_nodes(
|
||||
});
|
||||
}
|
||||
|
||||
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
|
||||
// Add one node per accessor
|
||||
for (accessor, column_type) in column_and_types {
|
||||
let missing_value_for_accessor = if use_special_missing_agg {
|
||||
None
|
||||
} else if let Some(m) = missing.as_ref() {
|
||||
get_missing_val_as_u64_lenient(column_type, m, field_name)?
|
||||
get_missing_val_as_u64_lenient(column_type, accessor.max_value(), m, field_name)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -917,13 +906,11 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: TermsAggregationInternal::from_req(req),
|
||||
// Will be filled later when building collectors
|
||||
sub_aggregation_blueprint: None,
|
||||
sug_aggregations: sub_aggs.clone(),
|
||||
allowed_term_ids,
|
||||
is_top_level,
|
||||
});
|
||||
(idx_in_req_data, AggKind::Terms)
|
||||
}
|
||||
@@ -933,7 +920,6 @@ fn build_terms_or_cardinality_nodes(
|
||||
column_type,
|
||||
str_dict_column: str_dict_column.clone(),
|
||||
missing_value_for_accessor,
|
||||
column_block_accessor: Default::default(),
|
||||
name: agg_name.to_string(),
|
||||
req: req.clone(),
|
||||
});
|
||||
|
||||
@@ -35,6 +35,7 @@ pub struct AggregationLimitsGuard {
|
||||
/// Allocated memory with this guard.
|
||||
allocated_with_the_guard: u64,
|
||||
}
|
||||
|
||||
impl Clone for AggregationLimitsGuard {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
@@ -70,7 +71,7 @@ impl AggregationLimitsGuard {
|
||||
/// *memory_limit*
|
||||
/// memory_limit is defined in bytes.
|
||||
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
|
||||
/// memory_limit.
|
||||
/// memory_limit.
|
||||
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
|
||||
///
|
||||
/// *bucket_limit*
|
||||
|
||||
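The guard documented above implements a simple accounting pattern: collectors report the bytes they allocate via add_memory_consumed, and the aggregation fails once the running total crosses memory_limit. A minimal sketch of that pattern under assumed names (the real AggregationLimitsGuard also shares its counter across the whole request and enforces bucket_limit):

// Sketch only, not the actual tantivy type.
pub struct MemoryGuardSketch {
    consumed: u64,
    memory_limit: u64, // in bytes
}

impl MemoryGuardSketch {
    // Mirrors the add_memory_consumed(..) calls seen in the collectors below.
    pub fn add_memory_consumed(&mut self, bytes: u64) -> Result<(), String> {
        self.consumed += bytes;
        if self.consumed > self.memory_limit {
            return Err(format!(
                "aggregation memory limit exceeded: {} > {} bytes",
                self.consumed, self.memory_limit
            ));
        }
        Ok(())
    }
}
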
@@ -16,7 +16,7 @@ use super::{AggregationError, Key};
use crate::TantivyError;

#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
/// The final aggegation result.
/// The final aggregation result.
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);

impl AggregationResults {

@@ -2,15 +2,441 @@ use serde_json::Value;

use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};

// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.

#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;

// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();

let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);

// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();

let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);

Ok(())
}

#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;

// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);

// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;

assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);

Ok(())
}

#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;

// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);

// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);

Ok(())
}

#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;

// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}

// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();

// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2

// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2

Ok(())
}

fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}

// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushing part of these tests is outdated, since the buffering changed after converting
// the collection into one collector per request instead of per bucket.
//
// However, they are still useful as they test complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(

let reader = index.reader()?;

assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

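Concretely, since COLLECT_BLOCK_BUFFER_LEN is 64, a request covering the range 0-70 first fills and flushes one complete block of 64 cached docs and then leaves 6 residue docs that are flushed at the end, which is exactly the pair of code paths this test wants to exercise. (The 64/6 split is arithmetic from the constants above, not a value taken from the test output.)
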
@@ -6,10 +6,12 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -32,7 +34,7 @@ use crate::{DocId, SegmentReader, TantivyError};
///
/// # Implementation Requirements
///
/// Implementors must:
/// Implementers must:
/// 1. Derive `Debug`, `Clone`, `Serialize`, and `Deserialize`
/// 2. Use `#[typetag::serde]` attribute on the impl block
/// 3. Implement `build_query()` to construct the query from schema/tokenizers
@@ -410,9 +412,9 @@ impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
}
}

@@ -489,12 +491,19 @@ impl Debug for DocumentQueryEvaluator {
}
}

#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
doc_count: u64,
bucket_id: BucketId,
}

/// Segment collector for filter aggregation
pub struct SegmentFilterCollector {
/// Document count in this bucket
doc_count: u64,
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregations: Option<CachedSubAggs<true>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
@@ -511,11 +520,13 @@ impl SegmentFilterCollector {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);

Ok(SegmentFilterCollector {
doc_count: 0,
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
@@ -523,35 +534,41 @@ impl SegmentFilterCollector {
impl Debug for SegmentFilterCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("doc_count", &self.doc_count)
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}

impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}

impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);

if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
}

// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: self.doc_count,
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
sub_aggregations: sub_results,
};

@@ -570,32 +587,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
Ok(())
}

fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);

// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;

// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}

#[inline]
fn collect_block(
fn collect(
&mut self,
docs: &[DocId],
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}

let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);

@@ -604,18 +606,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);

self.doc_count += req.matching_docs_buffer.len() as u64;
bucket.doc_count += req.matching_docs_buffer.len() as u64;

// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}

// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;

Ok(())
}

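The push/check_flush_local pair used above follows a simple buffering pattern: matching (bucket_id, doc) pairs are queued and only handed to the sub-aggregation collector once a block has accumulated. A minimal sketch of that idea, with assumed field names and flush threshold (the real CachedSubAggs is generic over a LOWCARD flag and lives in cached_sub_aggs):

// Sketch only: buffer (bucket_id, doc) pairs and flush them in blocks.
struct CachedSubAggsSketch {
    buffered: Vec<(u32, u32)>, // (BucketId, DocId) pairs; integer widths assumed
}

impl CachedSubAggsSketch {
    fn push(&mut self, bucket_id: u32, doc: u32) {
        self.buffered.push((bucket_id, doc));
    }

    // Returns the pairs to forward once the buffer is full; 64 is an assumed threshold.
    fn check_flush_local(&mut self) -> Option<Vec<(u32, u32)>> {
        if self.buffered.len() >= 64 {
            Some(std::mem::take(&mut self.buffered))
        } else {
            None
        }
    }
}
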
@@ -626,6 +634,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
}

/// Intermediate result for filter aggregation
@@ -639,16 +662,14 @@ pub struct IntermediateFilterBucketResult {

#[cfg(test)]
mod tests {
use std::time::Instant;

use serde_json::{json, Value};

use super::*;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::{AggContextParams, AggregationCollector};
use crate::query::{AllQuery, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, STORED, TEXT};
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, Term, FAST, INDEXED, TEXT};
use crate::{doc, Index, IndexWriter};

// Test helper functions
@@ -729,12 +750,13 @@ mod tests {

let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer(50_000_000)?;
let mut writer: IndexWriter = index.writer_for_tests()?;

writer.add_document(doc!(
category => "electronics", brand => "apple",
price => 999u64, rating => 4.5f64, in_stock => true
))?;
writer.commit()?;
writer.add_document(doc!(
category => "electronics", brand => "samsung",
price => 799u64, rating => 4.2f64, in_stock => true
@@ -938,7 +960,7 @@ mod tests {
let index = create_standard_test_index()?;
let reader = index.reader()?;
let searcher = reader.searcher();

assert_eq!(searcher.segment_readers().len(), 2);
let agg = json!({
"premium_electronics": {
"filter": "category:electronics AND price:[800 TO *]",
@@ -1520,9 +1542,9 @@ mod tests {
let searcher = reader.searcher();

let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});


@@ -1,6 +1,6 @@
use std::cmp::Ordering;

use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;

@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -257,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}

impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregation: &mut Option<CachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
}
}

#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}

/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<CachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}

impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;

Ok(())
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}

#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;

let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;

req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, &req.field_type);
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
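The bucket math in the loop above is implied by the helper names rather than shown in this hunk; a sketch consistent with how get_bucket_pos_f64 and get_bucket_key_from_pos are used here (the real helpers live elsewhere in this module and may differ in detail):

// Sketch of the histogram bucketing math, inferred from the helper names above.
fn get_bucket_pos_f64_sketch(val: f64, interval: f64, offset: f64) -> f64 {
    ((val - offset) / interval).floor()
}

fn get_bucket_key_from_pos_sketch(pos: f64, interval: f64, offset: f64) -> f64 {
    pos * interval + offset
}
// e.g. interval = 10.0, offset = 0.0: val 17.0 -> pos 1.0 -> key 10.0
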
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}

if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}

Ok(())
}

fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into an intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());

for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);

buckets.push(bucket_res?);
}
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let blueprint = if !node.children.is_empty() {
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);

req_data.sub_aggregation_blueprint = blueprint;
let sub_agg = sub_agg.map(CachedSubAggs::new);

Ok(Self {
buckets: Default::default(),
sub_aggregations: Default::default(),
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

@@ -1,18 +1,20 @@
use std::fmt::Debug;
use std::ops::Range;

use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;

@@ -23,12 +25,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}

impl RangeAggReqData {
@@ -151,19 +153,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {

/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
pub struct SegmentRangeCollector<const LOWCARD: bool = false> {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<LOWCARD>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets, so we keep track of the next available bucket id here.
/// This effectively flattens the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// For example, the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], and each of these has a
/// Range aggregation with 4 buckets. The Range aggregation will then create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This gives the Stats aggregation unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}

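The provider referenced in the comment above only needs to hand out monotonically increasing ids. A minimal sketch of what it boils down to (the real BucketIdProvider lives in segment_agg_result and may differ in representation; the u32 id type is an assumption):

// Sketch: a monotonically increasing BucketId counter.
#[derive(Default)]
struct BucketIdProviderSketch {
    next: u32, // assuming BucketId is a u32-sized integer
}

impl BucketIdProviderSketch {
    fn next_bucket_id(&mut self) -> u32 {
        let id = self.next;
        self.next += 1;
        id
    }
}
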
impl<const LOWCARD: bool> Debug for SegmentRangeCollector<LOWCARD> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}

/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -184,48 +214,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
let sub_aggregation = IntermediateAggregationResults::default();

Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res,
sub_aggregation_res: sub_aggregation,
from: self.from,
to: self.to,
})
}
}

impl SegmentAggregationCollector for SegmentRangeCollector {
impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<LOWCARD> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();

let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);

let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter()
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
})
.collect::<crate::Result<_>>()?;

@@ -242,73 +274,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}

#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
let req = agg_data.take_range_req_data(self.accessor_idx);

req.column_block_accessor.fetch_block(docs, &req.accessor);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);

for (doc, val) in req
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];

for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
}
}

agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}

Ok(())
}

fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}

fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}

Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;

// TODO: A better metric than is_top_level would be the expected number of buckets.
// E.g. if the range agg is not top level but the parent is a bucket agg with fewer than 10
// buckets, we are still in low-cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;

let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};

if is_low_card {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<true>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector {
sub_agg: sub_agg.map(CachedSubAggs::<false>::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
}

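Worth noting about the two branches above: both arms build the same struct and differ only in the const LOWCARD flag handed to CachedSubAggs, so the compiler monomorphizes two specialized collector types and the is_low_card check picks between them once, at construction time. Boxing the result as dyn SegmentAggregationCollector then hides the distinction from callers.
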
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};

impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};

let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -317,20 +390,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
let sub_aggregation = sub_agg_prototype.clone();
// let sub_aggregation = sub_agg_prototype.clone();

Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation,
bucket_id,
key,
from,
to,
@@ -339,27 +412,20 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;

req_data.context.limits.add_memory_consumed(
self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;

Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}

#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
Ok(buckets)
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
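get_bucket_pos above leans on a standard idiom: binary_search_by_key returns Err(insertion_point) when no bucket starts exactly at val, and the bucket covering val is the one just before that insertion point. A self-contained illustration of the same idiom on bare start values (hypothetical data; like the collector, it assumes val is not below the first start):

// Same lookup idiom on plain start values.
fn bucket_for(val: u64, starts: &[u64]) -> usize {
    starts
        .binary_search(&val)
        .unwrap_or_else(|insertion_point| insertion_point - 1)
}
// With starts = [0, 10, 30]: bucket_for(9) == 0, bucket_for(10) == 1, bucket_for(29) == 1.
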

/// Converts the user provided f64 range value to fast field value space.
///
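For the binary search and range keys to stay correct, the f64-to-u64 conversion mentioned here must preserve ordering. A sketch of the standard order-preserving bit trick (tantivy's actual to_u64/from_u64 helpers live in its common crate and may differ in detail, e.g. around NaN handling):

// Standard monotonic f64 -> u64 mapping: non-negative values get the sign bit set,
// negative values are bitwise-inverted, so u64 ordering matches f64 ordering.
fn f64_to_u64_monotonic(val: f64) -> u64 {
    let bits = val.to_bits();
    if bits >> 63 == 0 {
        bits ^ (1u64 << 63) // non-negative: flip sign bit
    } else {
        !bits // negative: flip all bits
    }
}
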
@@ -456,7 +522,7 @@ pub(crate) fn range_to_string(
|
||||
let val = i64::from_u64(val);
|
||||
format_date(val)
|
||||
} else {
|
||||
Ok(f64_from_fastfield_u64(val, field_type).to_string())
|
||||
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -506,30 +572,33 @@ mod tests {
            let to = if range.range.end == u64::MAX {
                None
            } else {
                Some(f64_from_fastfield_u64(range.range.end, &field_type))
                Some(f64_from_fastfield_u64(range.range.end, field_type))
            };
            let from = if range.range.start == u64::MIN {
                None
            } else {
                Some(f64_from_fastfield_u64(range.range.start, &field_type))
                Some(f64_from_fastfield_u64(range.range.start, field_type))
            };
            SegmentRangeAndBucketEntry {
                range: range.range.clone(),
                bucket: SegmentRangeBucketEntry {
                    doc_count: 0,
                    sub_aggregation: None,
                    key,
                    from,
                    to,
                    bucket_id: 0,
                },
            }
        })
        .collect();

        SegmentRangeCollector {
            buckets,
            parent_buckets: vec![buckets],
            column_type: field_type,
            accessor_idx: 0,
            sub_agg: None,
            bucket_id_provider: Default::default(),
            limits: AggregationLimitsGuard::default(),
        }
    }

@@ -776,7 +845,7 @@ mod tests {
        let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
        let collector = get_collector_from_ranges(buckets, ColumnType::F64);

        let buckets = collector.buckets;
        let buckets = collector.parent_buckets[0].clone();
        assert_eq!(buckets[0].range.start, u64::MIN);
        assert_eq!(buckets[0].range.end, 10f64.to_u64());
        assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -799,7 +868,7 @@ mod tests {
        ];
        let collector = get_collector_from_ranges(buckets, ColumnType::F64);

        let buckets = collector.buckets;
        let buckets = collector.parent_buckets[0].clone();
        assert_eq!(buckets[0].range.start, u64::MIN);
        assert_eq!(buckets[0].range.end, 10f64.to_u64());
        assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -814,7 +883,7 @@ mod tests {
        let buckets = vec![(-10f64..-1f64).into()];
        let collector = get_collector_from_ranges(buckets, ColumnType::F64);

        let buckets = collector.buckets;
        let buckets = collector.parent_buckets[0].clone();
        assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
        assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
    }
@@ -823,7 +892,7 @@ mod tests {
        let buckets = vec![(0f64..10f64).into()];
        let collector = get_collector_from_ranges(buckets, ColumnType::F64);

        let buckets = collector.buckets;
        let buckets = collector.parent_buckets[0].clone();
        assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
        assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
    }
@@ -832,7 +901,7 @@ mod tests {
    fn range_binary_search_test_u64() {
        let check_ranges = |ranges: Vec<RangeAggregationRange>| {
            let collector = get_collector_from_ranges(ranges, ColumnType::U64);
            let search = |val: u64| collector.get_bucket_pos(val);
            let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);

            assert_eq!(search(u64::MIN), 0);
            assert_eq!(search(9), 0);
@@ -878,7 +947,7 @@ mod tests {
        let ranges = vec![(10.0..100.0).into()];

        let collector = get_collector_from_ranges(ranges, ColumnType::F64);
        let search = |val: u64| collector.get_bucket_pos(val);
        let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);

        assert_eq!(search(u64::MIN), 0);
        assert_eq!(search(9f64.to_u64()), 0);
@@ -890,63 +959,3 @@ mod tests {
        // the max value
    }
}

#[cfg(all(test, feature = "unstable"))]
mod bench {

    use itertools::Itertools;
    use rand::seq::SliceRandom;
    use rand::thread_rng;

    use super::*;
    use crate::aggregation::bucket::range::tests::get_collector_from_ranges;

    const TOTAL_DOCS: u64 = 1_000_000u64;
    const NUM_DOCS: u64 = 50_000u64;

    fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
        let bucket_size = num_docs / num_buckets;
        let mut buckets: Vec<RangeAggregationRange> = vec![];
        for i in 0..num_buckets {
            let bucket_start = (i * bucket_size) as f64;
            buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
        }

        get_collector_from_ranges(buckets, ColumnType::U64)
    }

    fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
        let mut rng = thread_rng();

        let all_docs = (0..total_docs - 1).collect_vec();
        let mut vals = all_docs
            .as_slice()
            .choose_multiple(&mut rng, num_docs_returned as usize)
            .cloned()
            .collect_vec();
        vals.sort();
        vals
    }

    fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
        let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
        let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
        b.iter(|| {
            let mut bucket_pos = 0;
            for val in &vals {
                bucket_pos = collector.get_bucket_pos(*val);
            }
            bucket_pos
        })
    }

    #[bench]
    fn bench_range_100_buckets(b: &mut test::Bencher) {
        bench_range_binary_search(b, 100)
    }

    #[bench]
    fn bench_range_10_buckets(b: &mut test::Bencher) {
        bench_range_binary_search(b, 10)
    }
}

File diff suppressed because it is too large
@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
    build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
    IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;

/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
    }
}

/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
struct MissingCount {
    missing_count: u32,
    bucket_id: BucketId,
}

/// The specialized missing term aggregation.
#[derive(Default, Debug)]
pub struct TermMissingAgg {
    accessor_idx: usize,
    sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
    sub_agg: Option<CachedSubAggs>,
    /// Idx = parent bucket id, Value = missing count for that bucket
    missing_count_per_bucket: Vec<MissingCount>,
    bucket_id_provider: BucketIdProvider,
}
impl TermMissingAgg {
    pub(crate) fn new(
        req_data: &mut AggregationsSegmentCtx,
        agg_data: &mut AggregationsSegmentCtx,
        node: &AggRefNode,
    ) -> crate::Result<Self> {
        let has_sub_aggregations = !node.children.is_empty();
        let accessor_idx = node.idx_in_req_data;
        let sub_agg = if has_sub_aggregations {
            let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
            let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
            Some(sub_aggregation)
        } else {
            None
        };

        let sub_agg = sub_agg.map(CachedSubAggs::new);
        let bucket_id_provider = BucketIdProvider::default();

        Ok(Self {
            accessor_idx,
            sub_agg,
            ..Default::default()
            missing_count_per_bucket: Vec::new(),
            bucket_id_provider,
        })
    }
}

impl SegmentAggregationCollector for TermMissingAgg {
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
        let term_agg = &req_data.req;
        let missing = term_agg
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
        let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
            Default::default();

        let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
        let mut missing_entry = IntermediateTermBucketEntry {
            doc_count: self.missing_count,
            doc_count: missing_count.missing_count,
            sub_aggregation: Default::default(),
        };
        if let Some(sub_agg) = self.sub_agg {
        if let Some(sub_agg) = &mut self.sub_agg {
            let mut res = IntermediateAggregationResults::default();
            sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
            sub_agg
                .get_sub_agg_collector()
                .add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
            missing_entry.sub_aggregation = res;
        }
        entries.insert(missing.into(), missing_entry);
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {

    fn collect(
        &mut self,
        doc: crate::DocId,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
        let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
        let has_value = req_data
            .accessors
            .iter()
            .any(|(acc, _)| acc.index.has_value(doc));
        if !has_value {
            self.missing_count += 1;
            if let Some(sub_agg) = self.sub_agg.as_mut() {
                sub_agg.collect(doc, agg_data)?;

        for doc in docs {
            let doc = *doc;
            let has_value = req_data
                .accessors
                .iter()
                .any(|(acc, _)| acc.index.has_value(doc));
            if !has_value {
                bucket.missing_count += 1;

                if let Some(sub_agg) = self.sub_agg.as_mut() {
                    sub_agg.push(bucket.bucket_id, doc);
                }
            }
        }

        if let Some(sub_agg) = self.sub_agg.as_mut() {
            sub_agg.check_flush_local(agg_data)?;
        }
        Ok(())
    }

    fn collect_block(
    fn prepare_max_bucket(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        for doc in docs {
            self.collect(*doc, agg_data)?;
        while self.missing_count_per_bucket.len() <= max_bucket as usize {
            let bucket_id = self.bucket_id_provider.next_bucket_id();
            self.missing_count_per_bucket.push(MissingCount {
                missing_count: 0,
                bucket_id,
            });
        }
        Ok(())
    }

    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        if let Some(sub_agg) = self.sub_agg.as_mut() {
            sub_agg.flush(agg_data)?;
        }
        Ok(())
    }

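The prepare_max_bucket pattern above, reduced to a standalone sketch (illustration, not tantivy API; BucketId is a u32 and parent bucket ids are assumed dense):

struct Counts {
    per_bucket: Vec<u32>,
}

impl Counts {
    // Grow lazily so collect() can index per_bucket[parent_bucket_id] safely.
    fn prepare_max_bucket(&mut self, max_bucket: u32) {
        while self.per_bucket.len() <= max_bucket as usize {
            self.per_bucket.push(0);
        }
    }
}
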
@@ -1,83 +0,0 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;

pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];

/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
    pub(crate) collector: Box<dyn SegmentAggregationCollector>,
    staged_docs: DocBlock,
    num_staged_docs: usize,
}

impl std::fmt::Debug for BufAggregationCollector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SegmentAggregationResultsCollector")
            .field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
            .field("num_staged_docs", &self.num_staged_docs)
            .finish()
    }
}

impl BufAggregationCollector {
    pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
        Self {
            collector,
            num_staged_docs: 0,
            staged_docs: [0; DOC_BLOCK_SIZE],
        }
    }
}

impl SegmentAggregationCollector for BufAggregationCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
    ) -> crate::Result<()> {
        Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
    }

    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.staged_docs[self.num_staged_docs] = doc;
        self.num_staged_docs += 1;
        if self.num_staged_docs == self.staged_docs.len() {
            self.collector
                .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
            self.num_staged_docs = 0;
        }
        Ok(())
    }

    #[inline]
    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.collector.collect_block(docs, agg_data)?;

        Ok(())
    }

    #[inline]
    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        self.collector
            .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
        self.num_staged_docs = 0;

        self.collector.flush(agg_data)?;

        Ok(())
    }
}
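For reference, the staging idea the deleted BufAggregationCollector implemented, in isolation (a sketch, not tantivy API): single doc ids accumulate in a fixed array and are handed downstream in blocks.

struct Staged {
    staged: [u32; 64],
    len: usize,
}

impl Staged {
    fn push(&mut self, doc: u32, mut sink: impl FnMut(&[u32])) {
        self.staged[self.len] = doc;
        self.len += 1;
        if self.len == self.staged.len() {
            sink(&self.staged); // forward a full block of 64 docs
            self.len = 0;
        }
    }
}

Its replacement below, CachedSubAggs, groups the buffered docs by bucket id instead of forwarding them in arrival order.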
src/aggregation/cached_sub_aggs.rs (new file, 185 lines)
@@ -0,0 +1,185 @@
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;

#[derive(Debug)]
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this data structure does in general is group docs by bucket id.
pub(crate) struct CachedSubAggs<const LOWCARD: bool = false> {
    /// Only used when LOWCARD is true.
    /// Cache doc ids per bucket for sub-aggregations.
    ///
    /// The outer Vec is indexed by BucketId.
    per_bucket_docs: Vec<Vec<DocId>>,
    /// Only used when LOWCARD is false.
    ///
    /// This weird partitioning is used to do some cheap grouping on the bucket ids.
    /// Bucket ids are dense, so e.g. even when we don't detect the cardinality as low
    /// but there are just 16 bucket ids, each bucket id will go to its own partition.
    ///
    /// We want to keep this cheap, because high cardinality aggregations can have a lot of
    /// buckets, and there may be nothing to group.
    partitions: [PartitionEntry; NUM_PARTITIONS],
    pub(crate) sub_agg_collector: Box<dyn SegmentAggregationCollector>,
    num_docs: usize,
}

const FLUSH_THRESHOLD: usize = 2048;
const NUM_PARTITIONS: usize = 16;

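The two strategies above, reduced (illustration only): LOWCARD keeps one Vec per dense bucket id, while the default path scatters docs over NUM_PARTITIONS slots keyed by a cheap modulo, grouping ids without hashing.

// With 16 partitions, dense bucket ids 0..16 each land in their own slot;
// larger id spaces share slots but stay roughly balanced.
fn partition_for(bucket_id: u32, num_partitions: u32) -> usize {
    (bucket_id % num_partitions) as usize
}
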
impl<const LOWCARD: bool> CachedSubAggs<LOWCARD> {
    pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
        &mut self.sub_agg_collector
    }

    pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
        Self {
            per_bucket_docs: Vec::new(),
            num_docs: 0,
            sub_agg_collector: sub_agg,
            partitions: core::array::from_fn(|_| PartitionEntry::default()),
        }
    }

    #[inline]
    pub fn clear(&mut self) {
        for v in &mut self.per_bucket_docs {
            v.clear();
        }
        for partition in &mut self.partitions {
            partition.clear();
        }
        self.num_docs = 0;
    }

    #[inline]
    pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
        if LOWCARD {
            // TODO: We could flush single buckets here
            let idx = bucket_id as usize;
            if self.per_bucket_docs.len() <= idx {
                self.per_bucket_docs.resize_with(idx + 1, Vec::new);
            }
            self.per_bucket_docs[idx].push(doc_id);
        } else {
            let idx = bucket_id % NUM_PARTITIONS as u32;
            let slot = &mut self.partitions[idx as usize];
            slot.bucket_ids.push(bucket_id);
            slot.docs.push(doc_id);
        }
        self.num_docs += 1;
    }

    /// Check if we need to flush based on the number of documents cached.
    /// If so, flushes the cache to the provided aggregation collector.
    pub fn check_flush_local(
        &mut self,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if self.num_docs >= FLUSH_THRESHOLD {
            self.flush_local(agg_data, false)?;
        }
        Ok(())
    }

    /// Note: this does _not_ flush the sub aggregations
    fn flush_local(
        &mut self,
        agg_data: &mut AggregationsSegmentCtx,
        force: bool,
    ) -> crate::Result<()> {
        if LOWCARD {
            // Pre-aggregated: call collect per bucket.
            let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
            self.sub_agg_collector
                .prepare_max_bucket(max_bucket, agg_data)?;
            // The threshold above which we flush buckets individually.
            // Note: We need to make sure that we don't lock ourselves into a situation where we hit
            // the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
            let mut bucket_threshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
            const _: () = {
                // MAX_NUM_TERMS_FOR_VEC == LOWCARD threshold
                let bucket_threshold = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
                assert!(
                    bucket_threshold > 0,
                    "Bucket threshold must be greater than 0"
                );
            };
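            // Worked example (illustration, not part of the diff): with
            // FLUSH_THRESHOLD = 2048 and, say, 16 cached buckets, the per-bucket
            // threshold is 2048 / (16 * 2) = 64 docs. The const block above proves
            // the analogous value stays > 0 at the LOWCARD limit, so a flush can
            // never stall with every bucket sitting under the threshold.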
            if force {
                bucket_threshold = 0;
            }
            for (bucket_id, docs) in self
                .per_bucket_docs
                .iter()
                .enumerate()
                .filter(|(_, docs)| docs.len() > bucket_threshold)
            {
                self.sub_agg_collector
                    .collect(bucket_id as BucketId, docs, agg_data)?;
            }
        } else {
            let mut max_bucket = 0u32;
            for partition in &self.partitions {
                if let Some(&local_max) = partition.bucket_ids.iter().max() {
                    max_bucket = max_bucket.max(local_max);
                }
            }

            self.sub_agg_collector
                .prepare_max_bucket(max_bucket, agg_data)?;

            for slot in &self.partitions {
                if !slot.bucket_ids.is_empty() {
                    // Reduce dynamic dispatch overhead by collecting a full partition in one call.
                    self.sub_agg_collector.collect_multiple(
                        &slot.bucket_ids,
                        &slot.docs,
                        agg_data,
                    )?;
                }
            }
        }
        self.clear();
        Ok(())
    }

    /// Note: this _does_ flush the sub aggregations
    pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        if self.num_docs != 0 {
            self.flush_local(agg_data, true)?;
        }
        self.sub_agg_collector.flush(agg_data)?;
        Ok(())
    }
}

#[derive(Debug, Clone, Default)]
struct PartitionEntry {
    bucket_ids: Vec<BucketId>,
    docs: Vec<DocId>,
}

impl PartitionEntry {
    #[inline]
    fn clear(&mut self) {
        self.bucket_ids.clear();
        self.docs.clear();
    }
}

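How a bucket aggregation is expected to drive this cache, pieced together from the methods above (a sketch; error handling elided, names as in the diff):

// let mut cache = CachedSubAggs::<true>::new(sub_agg_collector);
// for (bucket_id, doc) in matching_docs {
//     cache.push(bucket_id, doc);              // group the doc under its bucket
//     cache.check_flush_local(&mut agg_data)?; // flushes once FLUSH_THRESHOLD docs are cached
// }
// cache.flush(&mut agg_data)?;                 // final flush, including the sub-aggregations
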
@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::cached_sub_aggs::CachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
    build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
    aggs_with_accessor: AggregationsSegmentCtx,
    agg_collector: BufAggregationCollector,
    agg_collector: CachedSubAggs<true>,
    error: Option<TantivyError>,
}

@@ -151,8 +151,10 @@ impl AggregationSegmentCollector {
    ) -> crate::Result<Self> {
        let mut agg_data =
            build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
        let result =
            BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
        let mut result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
        result
            .get_sub_agg_collector()
            .prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero

        Ok(AggregationSegmentCollector {
            aggs_with_accessor: agg_data,
@@ -170,26 +172,31 @@ impl SegmentCollector for AggregationSegmentCollector {
        if self.error.is_some() {
            return;
        }
        if let Err(err) = self
        self.agg_collector.push(0, doc);
        match self
            .agg_collector
            .collect(doc, &mut self.aggs_with_accessor)
            .check_flush_local(&mut self.aggs_with_accessor)
        {
            self.error = Some(err);
            Ok(_) => {}
            Err(e) => {
                self.error = Some(e);
            }
        }
    }

    /// The query pushes the documents to the collector via this method.
    ///
    /// Only valid for Collectors that ignore docs
    fn collect_block(&mut self, docs: &[DocId]) {
        if self.error.is_some() {
            return;
        }
        if let Err(err) = self
            .agg_collector
            .collect_block(docs, &mut self.aggs_with_accessor)
        {
            self.error = Some(err);

        match self.agg_collector.get_sub_agg_collector().collect(
            0,
            docs,
            &mut self.aggs_with_accessor,
        ) {
            Ok(_) => {}
            Err(e) => {
                self.error = Some(e);
            }
        }
    }

@@ -200,10 +207,13 @@ impl SegmentCollector for AggregationSegmentCollector {
        self.agg_collector.flush(&mut self.aggs_with_accessor)?;

        let mut sub_aggregation_res = IntermediateAggregationResults::default();
        Box::new(self.agg_collector).add_intermediate_aggregation_result(
            &self.aggs_with_accessor,
            &mut sub_aggregation_res,
        )?;
        self.agg_collector
            .get_sub_agg_collector()
            .add_intermediate_aggregation_result(
                &self.aggs_with_accessor,
                &mut sub_aggregation_res,
                0,
            )?;

        Ok(sub_aggregation_res)
    }

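The per-segment flow after this change, summarized from the hunks above (sketch):

// collect(doc):        agg_collector.push(0, doc) then check_flush_local(..)    // root bucket id is 0
// collect_block(docs): sub collector .collect(0, docs, ..)                      // whole block, bypassing the cache
// harvest:             flush(..) then add_intermediate_aggregation_result(.., 0)
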
@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
    /// The number of documents in the bucket.
    pub doc_count: u64,
    /// The sub_aggregation in this bucket.
    pub sub_aggregation: IntermediateAggregationResults,
    pub sub_aggregation_res: IntermediateAggregationResults,
    /// The from range of the bucket. Equals `f64::MIN` when `None`.
    pub from: Option<f64>,
    /// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
            key: self.key.into(),
            doc_count: self.doc_count,
            sub_aggregation: self
                .sub_aggregation
                .sub_aggregation_res
                .into_final_result_internal(req, limits)?,
            to: self.to,
            from: self.from,
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
    fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
        self.doc_count += other.doc_count;
        self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
        self.sub_aggregation_res
            .merge_fruits(other.sub_aggregation_res)?;
        Ok(())
    }
}
@@ -887,7 +888,7 @@ mod tests {
            IntermediateRangeBucketEntry {
                key: IntermediateKey::Str(key.to_string()),
                doc_count: *doc_count,
                sub_aggregation: Default::default(),
                sub_aggregation_res: Default::default(),
                from: None,
                to: None,
            },
@@ -920,7 +921,7 @@ mod tests {
            doc_count: *doc_count,
            from: None,
            to: None,
            sub_aggregation: get_sub_test_tree(&[(
            sub_aggregation_res: get_sub_test_tree(&[(
                sub_aggregation_key.to_string(),
                *sub_aggregation_count,
            )]),

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {

impl IntermediateAverage {
    /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateAverage) {

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};

use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
    pub str_dict_column: Option<StrColumn>,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_value_for_accessor: Option<u64>,
    /// The column block accessor to access the fast field values.
    pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
    /// The name of the aggregation.
    pub name: String,
    /// The aggregation request.
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
    }
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector {
    cardinality: CardinalityCollector,
    entries: FxHashSet<u64>,
    buckets: Vec<SegmentCardinalityCollectorBucket>,
    accessor_idx: usize,
    /// The column accessor to access the fast field values.
    accessor: Column<u64>,
    /// The column_type of the field.
    column_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    missing_value_for_accessor: Option<u64>,
}

impl SegmentCardinalityCollector {
    pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
    cardinality: CardinalityCollector,
    entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
    pub fn new(column_type: ColumnType) -> Self {
        Self {
            cardinality: CardinalityCollector::new(column_type as u8),
            entries: Default::default(),
            accessor_idx,
            entries: FxHashSet::default(),
        }
    }

    fn fetch_block_with_field(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut CardinalityAggReqData,
    ) {
        if let Some(missing) = agg_data.missing_value_for_accessor {
            agg_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &agg_data.accessor,
                missing,
            );
        } else {
            agg_data
                .column_block_accessor
                .fetch_block(docs, &agg_data.accessor);
        }
    }

    fn into_intermediate_metric_result(
        mut self,
        agg_data: &AggregationsSegmentCtx,
        req_data: &CardinalityAggReqData,
    ) -> crate::Result<IntermediateMetricResult> {
        let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
        if req_data.column_type == ColumnType::Str {
            let fallback_dict = Dictionary::empty();
            let dict = req_data
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
                    term_ids.push(term_ord as u32);
                }
            }

            term_ids.sort_unstable();
            dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
                self.cardinality.sketch.insert_any(&term);
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
    }
}

impl SegmentCardinalityCollector {
    pub fn from_req(
        column_type: ColumnType,
        accessor_idx: usize,
        accessor: Column<u64>,
        missing_value_for_accessor: Option<u64>,
    ) -> Self {
        Self {
            buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
            column_type,
            accessor_idx,
            accessor,
            missing_value_for_accessor,
        }
    }

    fn fetch_block_with_field(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) {
        agg_data.column_block_accessor.fetch_block_with_missing(
            docs,
            &self.accessor,
            self.missing_value_for_accessor,
        );
    }
}

impl SegmentAggregationCollector for SegmentCardinalityCollector {
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
        let name = req_data.name.to_string();
        // take the bucket in buckets and replace it with a new empty one
        let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);

        let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
        let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
        results.push(
            name,
            IntermediateAggregationResult::Metric(intermediate_result),
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {

    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.collect_block(&[doc], agg_data)
    }

    fn collect_block(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
        self.fetch_block_with_field(docs, req_data);
        self.fetch_block_with_field(docs, agg_data);
        let bucket = &mut self.buckets[parent_bucket_id as usize];

        let col_block_accessor = &req_data.column_block_accessor;
        if req_data.column_type == ColumnType::Str {
        let col_block_accessor = &agg_data.column_block_accessor;
        if self.column_type == ColumnType::Str {
            for term_ord in col_block_accessor.iter_vals() {
                self.entries.insert(term_ord);
                bucket.entries.insert(term_ord);
            }
        } else if req_data.column_type == ColumnType::IpAddr {
            let compact_space_accessor = req_data
        } else if self.column_type == ColumnType::IpAddr {
            let compact_space_accessor = self
                .accessor
                .values
                .clone()
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
                })?;
            for val in col_block_accessor.iter_vals() {
                let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
                self.cardinality.sketch.insert_any(&val);
                bucket.cardinality.sketch.insert_any(&val);
            }
        } else {
            for val in col_block_accessor.iter_vals() {
                self.cardinality.sketch.insert_any(&val);
                bucket.cardinality.sketch.insert_any(&val);
            }
        }

        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if max_bucket as usize >= self.buckets.len() {
            self.buckets.resize_with(max_bucket as usize + 1, || {
                SegmentCardinalityCollectorBucket::new(self.column_type)
            });
        }
        Ok(())
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]

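The per-parent-bucket shape introduced above, with an exact HashSet standing in for the HyperLogLog sketch (illustration only):

use std::collections::HashSet;

struct BucketCardinality {
    buckets: Vec<HashSet<u64>>, // indexed by parent bucket id
}

impl BucketCardinality {
    fn prepare_max_bucket(&mut self, max_bucket: u32) {
        if max_bucket as usize >= self.buckets.len() {
            self.buckets.resize_with(max_bucket as usize + 1, HashSet::new);
        }
    }
    fn collect(&mut self, parent_bucket_id: u32, vals: impl Iterator<Item = u64>) {
        self.buckets[parent_bucket_id as usize].extend(vals);
    }
    fn cardinality(&self, parent_bucket_id: u32) -> usize {
        self.buckets[parent_bucket_id as usize].len()
    }
}
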
@@ -52,10 +52,8 @@ pub struct IntermediateCount {

impl IntermediateCount {
    /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateCount) {

@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;

/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -62,7 +61,7 @@ impl ExtendedStatsAggregation {

/// Extended stats contains a collection of statistics;
/// they extend stats, adding variance, standard deviation
/// and bound informations
/// and bound information
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtendedStats {
    /// The number of documents.
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
    }
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector {
    name: String,
    missing: Option<u64>,
    field_type: ColumnType,
    pub(crate) extended_stats: IntermediateExtendedStats,
    pub(crate) accessor_idx: usize,
    val_cache: Vec<u64>,
    accessor: columnar::Column<u64>,
    buckets: Vec<IntermediateExtendedStats>,
    sigma: Option<f64>,
}

impl SegmentExtendedStatsCollector {
    pub fn from_req(
        field_type: ColumnType,
        sigma: Option<f64>,
        accessor_idx: usize,
        missing: Option<f64>,
    ) -> Self {
        let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
    pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
        let missing = req
            .missing
            .and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
        Self {
            field_type,
            extended_stats: IntermediateExtendedStats::with_sigma(sigma),
            accessor_idx,
            name: req.name.clone(),
            field_type: req.field_type,
            accessor: req.accessor.clone(),
            missing,
            val_cache: Default::default(),
        }
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = self.missing.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }
        for val in req_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, &self.field_type);
            self.extended_stats.collect(val1);
            buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
            sigma,
        }
    }
}
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
        let name = self.name.clone();
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
        results.push(
            name,
            IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
                self.extended_stats,
                extended_stats,
            )),
        )?;

@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data(self.accessor_idx);
        if let Some(missing) = self.missing {
            let mut has_val = false;
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &self.field_type);
                self.extended_stats.collect(val1);
                has_val = true;
            }
            if !has_val {
                self.extended_stats
                    .collect(f64_from_fastfield_u64(missing, &self.field_type));
            }
        } else {
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &self.field_type);
                self.extended_stats.collect(val1);
            }
        let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();

        agg_data
            .column_block_accessor
            .fetch_block_with_missing(docs, &self.accessor, self.missing);
        for val in agg_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, self.field_type);
            extended_stats.collect(val1);
        }

        // store back
        self.buckets[parent_bucket_id as usize] = extended_stats;

        Ok(())
    }

    #[inline]
    fn collect_block(
    fn prepare_max_bucket(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        if self.buckets.len() <= max_bucket as usize {
            self.buckets.resize_with(max_bucket as usize + 1, || {
                IntermediateExtendedStats::with_sigma(self.sigma)
            });
        }
        Ok(())
    }
}

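The statistics extended_stats adds on top of plain stats, as a worked sketch (illustration; the real IntermediateExtendedStats tracks these incrementally so per-segment results can be merged):

fn extended_stats(values: &[f64], sigma: f64) -> (f64, f64, f64, f64) {
    let count = values.len() as f64;
    let mean = values.iter().sum::<f64>() / count;
    // population variance: E[x^2] - E[x]^2
    let variance = values.iter().map(|v| v * v).sum::<f64>() / count - mean * mean;
    let std_dev = variance.sqrt();
    // the bounds controlled by the `sigma` parameter seen above
    let upper = mean + sigma * std_dev;
    let lower = mean - sigma * std_dev;
    (variance, std_dev, upper, lower)
}
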
@@ -52,10 +52,8 @@ pub struct IntermediateMax {

impl IntermediateMax {
    /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateMax) {

@@ -52,10 +52,8 @@ pub struct IntermediateMin {

impl IntermediateMin {
    /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateMin) {

@@ -31,7 +31,7 @@ use std::collections::HashMap;

pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
    pub field_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_u64: Option<u64>,
    /// The column block accessor to access the fast field values.
    pub column_block_accessor: ColumnBlockAccessor<u64>,
    /// The column accessor to access the fast field values.
    pub accessor: Column<u64>,
    /// Used when converting to intermediate result

@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;

/// # Percentiles
///
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
    }
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector {
    pub(crate) percentiles: PercentilesCollector,
    pub(crate) buckets: Vec<PercentilesCollector>,
    pub(crate) accessor_idx: usize,
    /// The type of the field.
    pub field_type: ColumnType,
    /// The missing value normalized to the internal u64 representation of the field type.
    pub missing_u64: Option<u64>,
    /// The column accessor to access the fast field values.
    pub accessor: Column<u64>,
}

#[derive(Clone, Serialize, Deserialize)]
@@ -229,33 +234,18 @@ impl PercentilesCollector {
}

impl SegmentPercentilesCollector {
    pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
        Ok(Self {
            percentiles: PercentilesCollector::new(),
    pub fn from_req_and_validate(
        field_type: ColumnType,
        missing_u64: Option<u64>,
        accessor: Column<u64>,
        accessor_idx: usize,
    ) -> Self {
        Self {
            buckets: Vec::with_capacity(64),
            field_type,
            missing_u64,
            accessor,
            accessor_idx,
        })
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = req_data.missing_u64.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }

        for val in req_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
            self.percentiles.collect(val1);
        }
    }
}
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
        let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        // Swap collector with an empty one to avoid cloning
        let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);

        let intermediate_metric_result =
            IntermediateMetricResult::Percentiles(percentiles_collector);

        results.push(
            name,
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data(self.accessor_idx);
        let percentiles = &mut self.buckets[parent_bucket_id as usize];
        agg_data.column_block_accessor.fetch_block_with_missing(
            docs,
            &self.accessor,
            self.missing_u64,
        );

        if let Some(missing) = req_data.missing_u64 {
            let mut has_val = false;
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.percentiles.collect(val1);
                has_val = true;
            }
            if !has_val {
                self.percentiles
                    .collect(f64_from_fastfield_u64(missing, &req_data.field_type));
            }
        } else {
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.percentiles.collect(val1);
            }
        for val in agg_data.column_block_accessor.iter_vals() {
            let val1 = f64_from_fastfield_u64(val, self.field_type);
            percentiles.collect(val1);
        }

        Ok(())
    }

    #[inline]
    fn collect_block(
    fn prepare_max_bucket(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        while self.buckets.len() <= max_bucket as usize {
            self.buckets.push(PercentilesCollector::new());
        }
        Ok(())
    }
}

@@ -1,5 +1,6 @@
use std::fmt::Debug;

use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};

use super::*;
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
    IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;

/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {

/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
    /// The number of extracted values.
    pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
    Percentiles,
}

fn create_collector<const TYPE_ID: u8>(
    req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
    Box::new(SegmentStatsCollector::<TYPE_ID> {
        name: req.name.clone(),
        collecting_for: req.collecting_for,
        is_number_or_date_type: req.is_number_or_date_type,
        missing_u64: req.missing_u64,
        accessor: req.accessor.clone(),
        buckets: vec![IntermediateStats::default()],
    })
}

/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
    req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    match req.field_type {
        ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
        ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
        ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
        ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
        ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
        ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
        ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
        ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
    }
}

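Why the match above monomorphizes: the column type becomes a const generic, so the per-value conversion compiles to a fixed path in each instantiation while callers still hold a Box<dyn ...>. A reduced sketch of the pattern (illustration only):

trait Collector {
    fn collect(&mut self, val: u64);
}

struct Stats<const TYPE_ID: u8> {
    sum: f64,
}

impl<const TYPE_ID: u8> Collector for Stats<TYPE_ID> {
    fn collect(&mut self, val: u64) {
        // TYPE_ID is a compile-time constant, so this branch folds away in each
        // monomorphized copy (0 stands in for a hypothetical signed-int type id).
        let v = if TYPE_ID == 0 { val as i64 as f64 } else { val as f64 };
        self.sum += v;
    }
}

fn build(type_id: u8) -> Box<dyn Collector> {
    match type_id {
        0 => Box::new(Stats::<0> { sum: 0.0 }),
        _ => Box::new(Stats::<1> { sum: 0.0 }),
    }
}
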
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
    pub(crate) stats: IntermediateStats,
    pub(crate) accessor_idx: usize,
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
    pub(crate) missing_u64: Option<u64>,
    pub(crate) accessor: Column<u64>,
    pub(crate) is_number_or_date_type: bool,
    pub(crate) buckets: Vec<IntermediateStats>,
    pub(crate) name: String,
    pub(crate) collecting_for: StatsType,
}

impl SegmentStatsCollector {
    pub fn from_req(accessor_idx: usize) -> Self {
        Self {
            stats: IntermediateStats::default(),
            accessor_idx,
        }
    }
    #[inline]
    pub(crate) fn collect_block_with_field(
        &mut self,
        docs: &[DocId],
        req_data: &mut MetricAggReqData,
    ) {
        if let Some(missing) = req_data.missing_u64.as_ref() {
            req_data.column_block_accessor.fetch_block_with_missing(
                docs,
                &req_data.accessor,
                *missing,
            );
        } else {
            req_data
                .column_block_accessor
                .fetch_block(docs, &req_data.accessor);
        }
        if req_data.is_number_or_date_type {
            for val in req_data.column_block_accessor.iter_vals() {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
            }
        } else {
            for _val in req_data.column_block_accessor.iter_vals() {
                // we ignore the value and simply record that we got something
                self.stats.collect(0.0);
            }
        }
    }
}

impl SegmentAggregationCollector for SegmentStatsCollector {
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
    for SegmentStatsCollector<COLUMN_TYPE_ID>
{
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let req = agg_data.get_metric_req_data(self.accessor_idx);
        let name = req.name.clone();
        let name = self.name.clone();

        let intermediate_metric_result = match req.collecting_for {
        self.prepare_max_bucket(parent_bucket_id, agg_data)?;
        let stats = self.buckets[parent_bucket_id as usize];
        let intermediate_metric_result = match self.collecting_for {
            StatsType::Average => {
                IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
                IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
            }
            StatsType::Count => {
                IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
                IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
            }
            StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
            StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
            StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
            StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
            StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
            StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
            StatsType::Stats => IntermediateMetricResult::Stats(stats),
            StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
            _ => {
                return Err(TantivyError::InvalidArgument(format!(
                    "Unsupported stats type for stats aggregation: {:?}",
                    req.collecting_for
                    self.collecting_for
                )))
            }
        };
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data(self.accessor_idx);
        if let Some(missing) = req_data.missing_u64 {
            let mut has_val = false;
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
                has_val = true;
            }
            if !has_val {
                self.stats
                    .collect(f64_from_fastfield_u64(missing, &req_data.field_type));
            }
        } else {
            for val in req_data.accessor.values_for_doc(doc) {
                let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
                self.stats.collect(val1);
            }
        }

        Ok(())
    }

    #[inline]
    fn collect_block(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
        self.collect_block_with_field(docs, req_data);
        // TODO: remove once we fetch all values for all bucket ids in one go
        if docs.len() == 1 && self.missing_u64.is_none() {
            collect_stats::<COLUMN_TYPE_ID>(
                &mut self.buckets[parent_bucket_id as usize],
                self.accessor.values_for_doc(docs[0]),
                self.is_number_or_date_type,
            )?;

            return Ok(());
        }
        agg_data.column_block_accessor.fetch_block_with_missing(
            docs,
            &self.accessor,
            self.missing_u64,
        );
        collect_stats::<COLUMN_TYPE_ID>(
            &mut self.buckets[parent_bucket_id as usize],
            agg_data.column_block_accessor.iter_vals(),
            self.is_number_or_date_type,
        )?;

        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let required_buckets = (max_bucket as usize) + 1;
        if self.buckets.len() < required_buckets {
            self.buckets
                .resize_with(required_buckets, IntermediateStats::default);
        }
        Ok(())
    }
}

#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
    stats: &mut IntermediateStats,
    vals: impl Iterator<Item = u64>,
    is_number_or_date_type: bool,
) -> crate::Result<()> {
    if is_number_or_date_type {
        for val in vals {
            let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
            stats.collect(val1);
        }
    } else {
        for _val in vals {
            // we ignore the value and simply record that we got something
            stats.collect(0.0);
        }
    }

    Ok(())
}

#[cfg(test)]

@@ -52,10 +52,8 @@ pub struct IntermediateSum {

impl IntermediateSum {
    /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
    pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
        Self {
            stats: collector.stats,
        }
    pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
        Self { stats }
    }
    /// Merges the other intermediate result into self.
    pub fn merge_fruits(&mut self, other: IntermediateSum) {

@@ -15,11 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::AggregationError;
|
||||
use crate::aggregation::{AggregationError, BucketId};
|
||||
use crate::collector::sort_key::ReverseComparator;
|
||||
use crate::collector::TopNComputer;
|
||||
use crate::schema::OwnedValue;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
// duplicate import removed; already imported above
|
||||
|
||||
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -458,7 +458,7 @@ impl Eq for DocSortValuesAndFields {}
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer {
    req: TopHitsAggregationReq,
    top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
    top_n: TopNComputer<DocSortValuesAndFields, DocAddress, ReverseComparator>,
}

impl std::cmp::PartialEq for TopHitsTopNComputer {
@@ -471,7 +471,10 @@ impl TopHitsTopNComputer {
    /// Create a new TopHitsCollector
    pub fn new(req: &TopHitsAggregationReq) -> Self {
        Self {
            top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
            top_n: TopNComputer::new_with_comparator(
                req.size + req.from.unwrap_or(0),
                ReverseComparator,
            ),
            req: req.clone(),
        }
    }
@@ -482,7 +485,7 @@ impl TopHitsTopNComputer {

    pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> {
        for doc in other_fruit.top_n.into_vec() {
            self.collect(doc.feature, doc.doc);
            self.collect(doc.sort_key, doc.doc);
        }
        Ok(())
    }
@@ -494,9 +497,9 @@ impl TopHitsTopNComputer {
            .into_sorted_vec()
            .into_iter()
            .map(|doc| TopHitsVecEntry {
                sort: doc.feature.sorts.iter().map(|f| f.value).collect(),
                sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(),
                doc_value_fields: doc
                    .feature
                    .sort_key
                    .doc_value_fields
                    .into_iter()
                    .map(|(k, v)| (k, v.into()))
@@ -517,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
    segment_ordinal: SegmentOrdinal,
    accessor_idx: usize,
    top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>,
    buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
    num_hits: usize,
}

impl TopHitsSegmentCollector {
@@ -526,25 +530,35 @@ impl TopHitsSegmentCollector {
        accessor_idx: usize,
        segment_ordinal: SegmentOrdinal,
    ) -> Self {
        let num_hits = req.size + req.from.unwrap_or(0);
        Self {
            top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
            num_hits,
            segment_ordinal,
            accessor_idx,
            buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
        }
    }
    fn into_top_hits_collector(
        self,
    fn get_top_hits_computer(
        &mut self,
        parent_bucket_id: BucketId,
        value_accessors: &HashMap<String, Vec<DynamicColumn>>,
        req: &TopHitsAggregationReq,
    ) -> TopHitsTopNComputer {
        if parent_bucket_id as usize >= self.buckets.len() {
            return TopHitsTopNComputer::new(req);
        }
        let top_n = std::mem::replace(
            &mut self.buckets[parent_bucket_id as usize],
            TopNComputer::new(0),
        );
        let mut top_hits_computer = TopHitsTopNComputer::new(req);
        let top_results = self.top_n.into_vec();
        let top_results = top_n.into_vec();

        for res in top_results {
            let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
            top_hits_computer.collect(
                DocSortValuesAndFields {
                    sorts: res.feature,
                    sorts: res.sort_key,
                    doc_value_fields,
                },
                res.doc,
@@ -553,54 +567,24 @@ impl TopHitsSegmentCollector {

        top_hits_computer
    }

    /// TODO add a specialized variant for a single sort field
    fn collect_with(
        &mut self,
        doc_id: crate::DocId,
        req: &TopHitsAggregationReq,
        accessors: &[(Column<u64>, ColumnType)],
    ) -> crate::Result<()> {
        let sorts: Vec<DocValueAndOrder> = req
            .sort
            .iter()
            .enumerate()
            .map(|(idx, KeyOrder { order, .. })| {
                let order = *order;
                let value = accessors
                    .get(idx)
                    .expect("could not find field in accessors")
                    .0
                    .values_for_doc(doc_id)
                    .next();
                DocValueAndOrder { value, order }
            })
            .collect();

        self.top_n.push(
            sorts,
            DocAddress {
                segment_ord: self.segment_ordinal,
                doc_id,
            },
        );
        Ok(())
    }
}

impl SegmentAggregationCollector for TopHitsSegmentCollector {
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);

        let value_accessors = &req_data.value_accessors;

        let intermediate_result = IntermediateMetricResult::TopHits(
            self.into_top_hits_collector(value_accessors, &req_data.req),
        );
        let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
            parent_bucket_id,
            value_accessors,
            &req_data.req,
        ));
        results.push(
            req_data.name.to_string(),
            IntermediateAggregationResult::Metric(intermediate_result),
@@ -610,26 +594,57 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
    /// TODO: Consider a caching layer to reduce the call overhead
    fn collect(
        &mut self,
        doc_id: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
        self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
        Ok(())
    }

    fn collect_block(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        let top_n = &mut self.buckets[parent_bucket_id as usize];
        let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
        // TODO: Consider getting fields with the column block accessor.
        for doc in docs {
            self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
        let req = &req_data.req;
        let accessors = &req_data.accessors;
        for doc_id in docs {
            let doc_id = *doc_id;
            // TODO: this is terrible, a new vec is allocated for every doc
            // We can fetch blocks instead
            // We don't need to store the order for every value
            let sorts: Vec<DocValueAndOrder> = req
                .sort
                .iter()
                .enumerate()
                .map(|(idx, KeyOrder { order, .. })| {
                    let order = *order;
                    let value = accessors
                        .get(idx)
                        .expect("could not find field in accessors")
                        .0
                        .values_for_doc(doc_id)
                        .next();
                    DocValueAndOrder { value, order }
                })
                .collect();

            top_n.push(
                sorts,
                DocAddress {
                    segment_ord: self.segment_ordinal,
                    doc_id,
                },
            );
        }
        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        _agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.buckets.resize(
            (max_bucket as usize) + 1,
            TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
        );
        Ok(())
    }
}

#[cfg(test)]
@@ -645,6 +660,7 @@ mod tests {
    use crate::aggregation::bucket::tests::get_test_index_from_docs;
    use crate::aggregation::tests::get_test_index_from_values;
    use crate::aggregation::AggregationCollector;
    use crate::collector::sort_key::ReverseComparator;
    use crate::collector::ComparableDoc;
    use crate::query::AllQuery;
    use crate::schema::OwnedValue;
@@ -660,7 +676,7 @@ mod tests {

    fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer {
        super::TopHitsTopNComputer {
            top_n: super::TopNComputer::new(capacity),
            top_n: super::TopNComputer::new_with_comparator(capacity, ReverseComparator),
            req: Default::default(),
        }
    }
@@ -744,7 +760,7 @@ mod tests {
            ],
            "from": 0,
        }
    }
}
}))
.unwrap();

@@ -774,12 +790,12 @@ mod tests {
    #[test]
    fn test_top_hits_collector_single_feature() -> crate::Result<()> {
        let docs = vec![
            ComparableDoc::<_, _, false> {
            ComparableDoc::<_, _> {
                doc: crate::DocAddress {
                    segment_ord: 0,
                    doc_id: 0,
                },
                feature: DocSortValuesAndFields {
                sort_key: DocSortValuesAndFields {
                    sorts: vec![DocValueAndOrder {
                        value: Some(1),
                        order: Order::Asc,
@@ -792,7 +808,7 @@ mod tests {
                    segment_ord: 0,
                    doc_id: 2,
                },
                feature: DocSortValuesAndFields {
                sort_key: DocSortValuesAndFields {
                    sorts: vec![DocValueAndOrder {
                        value: Some(3),
                        order: Order::Asc,
@@ -805,7 +821,7 @@ mod tests {
                    segment_ord: 0,
                    doc_id: 1,
                },
                feature: DocSortValuesAndFields {
                sort_key: DocSortValuesAndFields {
                    sorts: vec![DocValueAndOrder {
                        value: Some(5),
                        order: Order::Asc,
@@ -817,7 +833,7 @@ mod tests {

        let mut collector = collector_with_capacity(3);
        for doc in docs.clone() {
            collector.collect(doc.feature, doc.doc);
            collector.collect(doc.sort_key, doc.doc);
        }

        let res = collector.into_final_result();
@@ -827,15 +843,15 @@ mod tests {
            super::TopHitsMetricResult {
                hits: vec![
                    super::TopHitsVecEntry {
                        sort: vec![docs[0].feature.sorts[0].value],
                        sort: vec![docs[0].sort_key.sorts[0].value],
                        doc_value_fields: Default::default(),
                    },
                    super::TopHitsVecEntry {
                        sort: vec![docs[1].feature.sorts[0].value],
                        sort: vec![docs[1].sort_key.sorts[0].value],
                        doc_value_fields: Default::default(),
                    },
                    super::TopHitsVecEntry {
                        sort: vec![docs[2].feature.sorts[0].value],
                        sort: vec![docs[2].sort_key.sorts[0].value],
                        doc_value_fields: Default::default(),
                    },
                ]
@@ -873,7 +889,7 @@ mod tests {
                "mixed.*",
            ],
        }
    }
}
}))?;

        let collector = AggregationCollector::from_aggs(d, Default::default());

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};

use crate::tokenizer::TokenizerManager;

/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that holds per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned an incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows us to have a single AggregationCollector instance per aggregation
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer: one Vec per bucket aggregation.
pub type BucketId = u32;
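
// Editorial sketch (illustrative, not part of this diff): because BucketIds
// are dense and start at 0, per-bucket state can live in a flat Vec indexed
// by the id, grown on demand in the spirit of `prepare_max_bucket`.
fn count_per_bucket(bucket_ids: &[BucketId], counts: &mut Vec<u64>) {
    for &bucket_id in bucket_ids {
        let idx = bucket_id as usize;
        if counts.len() <= idx {
            counts.resize(idx + 1, 0);
        }
        counts[idx] += 1;
    }
}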

/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -335,19 +348,37 @@ impl Display for Key {
    }
}

pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
    if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
        val as f64
    } else if COLUMN_TYPE_ID == ColumnType::I64 as u8
        || COLUMN_TYPE_ID == ColumnType::DateTime as u8
    {
        i64::from_u64(val) as f64
    } else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
        f64::from_u64(val)
    } else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
        val as f64
    } else {
        panic!(
            "ColumnType ID {} cannot be converted to f64 metric",
            COLUMN_TYPE_ID
        )
    }
}

/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
    match field_type {
        ColumnType::U64 => val as f64,
        ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
        ColumnType::F64 => f64::from_u64(val),
        ColumnType::Bool => val as f64,
        _ => {
            panic!("unexpected type {field_type:?}. This should not happen")
        }
        ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
        ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
        ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
        ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
        ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
        _ => panic!("unexpected type {field_type:?}. This should not happen"),
    }
}

@@ -8,25 +8,67 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;

/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
    /// Get the next BucketId.
    pub fn next_bucket_id(&mut self) -> BucketId {
        let bucket_id = self.0;
        self.0 += 1;
        bucket_id
    }
}
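
// Editorial doc-test-style sketch for the provider above (not from this diff):
// ids are dense and start at 0, so they can index directly into per-bucket Vecs.
fn demo_bucket_id_provider() {
    let mut provider = BucketIdProvider::default();
    assert_eq!(provider.next_bucket_id(), 0);
    assert_eq!(provider.next_bucket_id(), 1);
}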

/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug {
pub trait SegmentAggregationCollector: Debug {
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()>;

    /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
    fn collect(
        &mut self,
        doc: crate::DocId,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()>;

    fn collect_block(
    /// Collect docs for multiple buckets in one call.
    /// Minimizes dynamic dispatch overhead when collecting many buckets.
    ///
    /// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
    fn collect_multiple(
        &mut self,
        bucket_ids: &[BucketId],
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        debug_assert_eq!(bucket_ids.len(), docs.len());
        let mut start = 0;
        while start < bucket_ids.len() {
            let bucket_id = bucket_ids[start];
            let mut end = start + 1;
            while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
                end += 1;
            }
            self.collect(bucket_id, &docs[start..end], agg_data)?;
            start = end;
        }
        Ok(())
    }
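
    // Editorial note (illustrative, not from this diff): the default body above
    // groups consecutive equal bucket ids into runs, so each run costs one
    // virtual `collect` call instead of one per doc. In miniature:
    //
    //     fn runs(bucket_ids: &[u32]) -> Vec<(u32, std::ops::Range<usize>)> {
    //         let mut out = Vec::new();
    //         let mut start = 0;
    //         while start < bucket_ids.len() {
    //             let bucket_id = bucket_ids[start];
    //             let mut end = start + 1;
    //             while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
    //                 end += 1;
    //             }
    //             out.push((bucket_id, start..end));
    //             start = end;
    //         }
    //         out
    //     }
    //
    //     // runs(&[7, 7, 3, 3, 3, 7]) == vec![(7, 0..2), (3, 2..5), (7, 5..6)]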

    /// Prepare the collector for collecting up to BucketId `max_bucket`.
    /// This is useful so we can do the allocation ahead of time, before collecting.
    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()>;

    /// Finalize method. Some aggregators collect blocks of docs before calling `collect_block`.
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
    }
}

/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
    fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}

impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
    fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn SegmentAggregationCollector> {
    fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
        self.clone_box()
    }
}

#[derive(Clone, Default)]
#[derive(Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {

impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        &mut self,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
        parent_bucket_id: BucketId,
    ) -> crate::Result<()> {
        for agg in self.aggs {
            agg.add_intermediate_aggregation_result(agg_data, results)?;
        for agg in &mut self.aggs {
            agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
        }

        Ok(())
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {

    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        self.collect_block(&[doc], agg_data)?;

        Ok(())
    }

    fn collect_block(
        &mut self,
        parent_bucket_id: BucketId,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        for collector in &mut self.aggs {
            collector.collect_block(docs, agg_data)?;
            collector.collect(parent_bucket_id, docs, agg_data)?;
        }

        Ok(())
    }

@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
        }
        Ok(())
    }

    fn prepare_max_bucket(
        &mut self,
        max_bucket: BucketId,
        agg_data: &AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        for collector in &mut self.aggs {
            collector.prepare_max_bucket(max_bucket, agg_data)?;
        }
        Ok(())
    }
}

@@ -1,121 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score, SegmentReader};

pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
    custom_scorer: TCustomScorer,
    collector: TopCollector<TScore>,
}

impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where TScore: Clone + PartialOrd
{
    pub(crate) fn new(
        custom_scorer: TCustomScorer,
        collector: TopCollector<TScore>,
    ) -> CustomScoreTopCollector<TCustomScorer, TScore> {
        CustomScoreTopCollector {
            custom_scorer,
            collector,
        }
    }
}

/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`CustomScorer`].
pub trait CustomSegmentScorer<TScore>: 'static {
    /// Computes the score of a specific `doc`.
    fn score(&mut self, doc: DocId) -> TScore;
}

/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomScorer` itself does not do much of the computation.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait CustomScorer<TScore>: Sync {
    /// Type of the associated [`CustomSegmentScorer`].
    type Child: CustomSegmentScorer<TScore>;
    /// Builds a child scorer for a specific segment. The child scorer is associated with
    /// a specific segment.
    fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>;
}

impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
    TCustomScorer: CustomScorer<TScore> + Send + Sync,
    TScore: 'static + PartialOrd + Clone + Send + Sync,
{
    type Fruit = Vec<(TScore, DocAddress)>;

    type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;

    fn for_segment(
        &self,
        segment_local_id: u32,
        segment_reader: &SegmentReader,
    ) -> crate::Result<Self::Child> {
        let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
        let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
        Ok(CustomScoreTopSegmentCollector {
            segment_collector,
            segment_scorer,
        })
    }

    fn requires_scoring(&self) -> bool {
        false
    }

    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> {
        self.collector.merge_fruits(segment_fruits)
    }
}

pub struct CustomScoreTopSegmentCollector<T, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
    T: CustomSegmentScorer<TScore>,
{
    segment_collector: TopSegmentCollector<TScore>,
    segment_scorer: T,
}

impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync,
    T: 'static + CustomSegmentScorer<TScore>,
{
    type Fruit = Vec<(TScore, DocAddress)>;

    fn collect(&mut self, doc: DocId, _score: Score) {
        let score = self.segment_scorer.score(doc);
        self.segment_collector.collect(doc, score);
    }

    fn harvest(self) -> Vec<(TScore, DocAddress)> {
        self.segment_collector.harvest()
    }
}

impl<F, TScore, T> CustomScorer<TScore> for F
where
    F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
    T: CustomSegmentScorer<TScore>,
{
    type Child = T;

    fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
        Ok((self)(segment_reader))
    }
}

impl<F, TScore> CustomSegmentScorer<TScore> for F
where F: 'static + FnMut(DocId) -> TScore
{
    fn score(&mut self, doc: DocId) -> TScore {
        (self)(doc)
    }
}
@@ -12,6 +12,7 @@ use std::marker::PhantomData;
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};

use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentReader};

/// The `FilterCollector` filters docs using a fast field value and a predicate.
@@ -49,13 +50,13 @@ use crate::{DocId, Score, SegmentReader};
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2));
/// let no_filter_collector = FilterCollector::new("price".to_string(), |value: u64| value > 20_120u64, TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
///
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2));
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new("price".to_string(), |value| value < 5u64, TopDocs::with_limit(2).order_by_score());
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
///
/// assert_eq!(filtered_top_docs.len(), 0);
@@ -104,6 +105,11 @@ where

    type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.collector.check_schema(schema)?;
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -120,6 +126,7 @@ where
            segment_collector,
            predicate: self.predicate.clone(),
            t_predicate_value: PhantomData,
            filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
        })
    }

@@ -140,6 +147,7 @@ pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue
    segment_collector: TSegmentCollector,
    predicate: TPredicate,
    t_predicate_value: PhantomData<TPredicateValue>,
    filtered_docs: Vec<DocId>,
}

impl<TSegmentCollector, TPredicate, TPredicateValue>
@@ -176,6 +184,20 @@ where
        }
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        self.filtered_docs.clear();
        for &doc in docs {
            // TODO: `accept_document` could be further optimized to do batch lookups of column
            // values for single-valued columns.
            if self.accept_document(doc) {
                self.filtered_docs.push(doc);
            }
        }
        if !self.filtered_docs.is_empty() {
            self.segment_collector.collect_block(&self.filtered_docs);
        }
    }

    fn harvest(self) -> TSegmentCollector::Fruit {
        self.segment_collector.harvest()
    }
@@ -218,7 +240,7 @@ where
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary")?;
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
/// let filter_collector = BytesFilterCollector::new("barcode".to_string(), |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2).order_by_score());
/// let top_docs = searcher.search(&query, &filter_collector)?;
///
/// assert_eq!(top_docs.len(), 1);
@@ -258,6 +280,10 @@ where

    type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.collector.check_schema(schema)
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -274,6 +300,7 @@ where
            segment_collector,
            predicate: self.predicate.clone(),
            buffer: Vec::new(),
            filtered_docs: Vec::with_capacity(crate::COLLECT_BLOCK_BUFFER_LEN),
        })
    }

@@ -296,6 +323,7 @@ where TPredicate: 'static
    segment_collector: TSegmentCollector,
    predicate: TPredicate,
    buffer: Vec<u8>,
    filtered_docs: Vec<DocId>,
}

impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
@@ -334,6 +362,20 @@ where
        }
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        self.filtered_docs.clear();
        for &doc in docs {
            // TODO: `accept_document` could be further optimized to do batch lookups of column
            // values for single-valued columns.
            if self.accept_document(doc) {
                self.filtered_docs.push(doc);
            }
        }
        if !self.filtered_docs.is_empty() {
            self.segment_collector.collect_block(&self.filtered_docs);
        }
    }

    fn harvest(self) -> TSegmentCollector::Fruit {
        self.segment_collector.harvest()
    }

@@ -57,7 +57,7 @@
//! # let query_parser = QueryParser::for_index(&index, vec![title]);
//! # let query = query_parser.parse_query("diary")?;
//! let (doc_count, top_docs): (usize, Vec<(Score, DocAddress)>) =
//!     searcher.search(&query, &(Count, TopDocs::with_limit(2)))?;
//!     searcher.search(&query, &(Count, TopDocs::with_limit(2).order_by_score()))?;
//! # Ok(())
//! # }
//! ```
@@ -83,11 +83,15 @@

use downcast_rs::impl_downcast;

use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};

mod count_collector;
pub use self::count_collector::Count;

/// Sort keys
pub mod sort_key;

mod histogram_collector;
pub use histogram_collector::HistogramCollector;

@@ -95,16 +99,13 @@ mod multi_collector;
pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};

mod top_collector;
pub use self::top_collector::ComparableDoc;

mod top_score_collector;
pub use self::top_collector::ComparableDoc;
pub use self::top_score_collector::{TopDocs, TopNComputer};

mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};

mod tweak_score_top_collector;
pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod sort_key_top_collector;
pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer};
mod facet_collector;
pub use self::facet_collector::{FacetCollector, FacetCounts};
use crate::query::Weight;
@@ -145,6 +146,11 @@ pub trait Collector: Sync + Send {
    /// Type of the `SegmentCollector` associated with this collector.
    type Child: SegmentCollector;

    /// Returns an error if the schema is not compatible with the collector.
    fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
        Ok(())
    }

    /// `set_segment` is called before beginning to enumerate
    /// on this segment.
    fn for_segment(
@@ -170,41 +176,50 @@ pub trait Collector: Sync + Send {
        segment_ord: u32,
        reader: &SegmentReader,
    ) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
        let with_scoring = self.requires_scoring();
        let mut segment_collector = self.for_segment(segment_ord, reader)?;

        match (reader.alive_bitset(), self.requires_scoring()) {
            (Some(alive_bitset), true) => {
                weight.for_each(reader, &mut |doc, score| {
                    if alive_bitset.is_alive(doc) {
                        segment_collector.collect(doc, score);
                    }
                })?;
            }
            (Some(alive_bitset), false) => {
                weight.for_each_no_score(reader, &mut |docs| {
                    for doc in docs.iter().cloned() {
                        if alive_bitset.is_alive(doc) {
                            segment_collector.collect(doc, 0.0);
                        }
                    }
                })?;
            }
            (None, true) => {
                weight.for_each(reader, &mut |doc, score| {
                    segment_collector.collect(doc, score);
                })?;
            }
            (None, false) => {
                weight.for_each_no_score(reader, &mut |docs| {
                    segment_collector.collect_block(docs);
                })?;
            }
        }

        default_collect_segment_impl(&mut segment_collector, weight, reader, with_scoring)?;
        Ok(segment_collector.harvest())
    }
}

pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
    segment_collector: &mut TSegmentCollector,
    weight: &dyn Weight,
    reader: &SegmentReader,
    with_scoring: bool,
) -> crate::Result<()> {
    match (reader.alive_bitset(), with_scoring) {
        (Some(alive_bitset), true) => {
            weight.for_each(reader, &mut |doc, score| {
                if alive_bitset.is_alive(doc) {
                    segment_collector.collect(doc, score);
                }
            })?;
        }
        (Some(alive_bitset), false) => {
            weight.for_each_no_score(reader, &mut |docs| {
                for doc in docs.iter().cloned() {
                    if alive_bitset.is_alive(doc) {
                        segment_collector.collect(doc, 0.0);
                    }
                }
            })?;
        }
        (None, true) => {
            weight.for_each(reader, &mut |doc, score| {
                segment_collector.collect(doc, score);
            })?;
        }
        (None, false) => {
            weight.for_each_no_score(reader, &mut |docs| {
                segment_collector.collect_block(docs);
            })?;
        }
    }
    Ok(())
}
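
// Editorial note (descriptive, not from this diff): the match above specializes
// once per segment on (deletes present?, scoring required?), so the per-doc hot
// loop carries no extra branch; the no-deletes/no-scoring arm can forward whole
// blocks of doc ids via `collect_block`.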

impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCollector> {
    type Fruit = Option<TSegmentCollector::Fruit>;
@@ -214,6 +229,12 @@ impl<TSegmentCollector: SegmentCollector> SegmentCollector for Option<TSegmentCo
        }
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        if let Some(segment_collector) = self {
            segment_collector.collect_block(docs);
        }
    }

    fn harvest(self) -> Self::Fruit {
        self.map(|segment_collector| segment_collector.harvest())
    }
@@ -224,6 +245,13 @@ impl<TCollector: Collector> Collector for Option<TCollector> {

    type Child = Option<<TCollector as Collector>::Child>;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        if let Some(underlying_collector) = self {
            underlying_collector.check_schema(schema)?;
        }
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: SegmentOrdinal,
@@ -299,6 +327,12 @@ where
    type Fruit = (Left::Fruit, Right::Fruit);
    type Child = (Left::Child, Right::Child);

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -342,6 +376,11 @@ where
        self.1.collect(doc, score);
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        self.0.collect_block(docs);
        self.1.collect_block(docs);
    }

    fn harvest(self) -> <Self as SegmentCollector>::Fruit {
        (self.0.harvest(), self.1.harvest())
    }
@@ -358,6 +397,13 @@ where
    type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
    type Child = (One::Child, Two::Child, Three::Child);

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        self.2.check_schema(schema)?;
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -407,6 +453,12 @@ where
        self.2.collect(doc, score);
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        self.0.collect_block(docs);
        self.1.collect_block(docs);
        self.2.collect_block(docs);
    }

    fn harvest(self) -> <Self as SegmentCollector>::Fruit {
        (self.0.harvest(), self.1.harvest(), self.2.harvest())
    }
@@ -424,6 +476,14 @@ where
    type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
    type Child = (One::Child, Two::Child, Three::Child, Four::Child);

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        self.2.check_schema(schema)?;
        self.3.check_schema(schema)?;
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -482,6 +542,13 @@ where
        self.3.collect(doc, score);
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        self.0.collect_block(docs);
        self.1.collect_block(docs);
        self.2.collect_block(docs);
        self.3.collect_block(docs);
    }

    fn harvest(self) -> <Self as SegmentCollector>::Fruit {
        (
            self.0.harvest(),

@@ -3,6 +3,7 @@ use std::ops::Deref;

use super::{Collector, SegmentCollector};
use crate::collector::Fruit;
use crate::schema::Schema;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader, TantivyError};

/// MultiFruit keeps Fruits from every nested Collector
@@ -16,6 +17,10 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
    type Fruit = Box<dyn Fruit>;
    type Child = Box<dyn BoxableSegmentCollector>;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)
    }

    fn for_segment(
        &self,
        segment_local_id: u32,
@@ -147,7 +152,7 @@ impl<TFruit: Fruit> FruitHandle<TFruit> {
/// let searcher = reader.searcher();
///
/// let mut collectors = MultiCollector::new();
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2));
/// let top_docs_handle = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
/// let count_handle = collectors.add_collector(Count);
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
@@ -194,6 +199,13 @@ impl Collector for MultiCollector<'_> {
    type Fruit = MultiFruit;
    type Child = MultiCollectorChild;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        for collector in &self.collector_wrappers {
            collector.check_schema(schema)?;
        }
        Ok(())
    }

    fn for_segment(
        &self,
        segment_local_id: SegmentOrdinal,
@@ -250,6 +262,12 @@ impl SegmentCollector for MultiCollectorChild {
        }
    }

    fn collect_block(&mut self, docs: &[DocId]) {
        for child in &mut self.children {
            child.collect_block(docs);
        }
    }

    fn harvest(self) -> MultiFruit {
        MultiFruit {
            sub_fruits: self
@@ -293,7 +311,7 @@ mod tests {
    let query = TermQuery::new(term, IndexRecordOption::Basic);

    let mut collectors = MultiCollector::new();
    let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2));
    let topdocs_handler = collectors.add_collector(TopDocs::with_limit(2).order_by_score());
    let count_handler = collectors.add_collector(Count);
    let mut multifruits = searcher.search(&query, &collectors).unwrap();

393
src/collector/sort_key/mod.rs
Normal file
@@ -0,0 +1,393 @@
mod order;
mod sort_by_score;
mod sort_by_static_fast_value;
mod sort_by_string;
mod sort_key_computer;

pub use order::*;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::ops::Range;

    use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString};
    use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
    use crate::indexer::NoMergePolicy;
    use crate::query::{AllQuery, QueryParser};
    use crate::schema::{Schema, FAST, TEXT};
    use crate::{DocAddress, Document, Index, Order, Score, Searcher};

    fn make_index() -> crate::Result<Index> {
        let mut schema_builder = Schema::builder();
        let id = schema_builder.add_u64_field("id", FAST);
        let city = schema_builder.add_text_field("city", TEXT | FAST);
        let catchphrase = schema_builder.add_text_field("catchphrase", TEXT);
        let altitude = schema_builder.add_f64_field("altitude", FAST);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);

        fn create_segment(index: &Index, docs: Vec<impl Document>) -> crate::Result<()> {
            let mut index_writer = index.writer_for_tests()?;
            index_writer.set_merge_policy(Box::new(NoMergePolicy));
            for doc in docs {
                index_writer.add_document(doc)?;
            }
            index_writer.commit()?;
            Ok(())
        }

        create_segment(
            &index,
            vec![
                doc!(
                    id => 0_u64,
                    city => "austin",
                    catchphrase => "Hills, Barbeque, Glow",
                    altitude => 149.0,
                ),
                doc!(
                    id => 1_u64,
                    city => "greenville",
                    catchphrase => "Grow, Glow, Glow",
                    altitude => 27.0,
                ),
            ],
        )?;
        create_segment(
            &index,
            vec![doc!(
                id => 2_u64,
                city => "tokyo",
                catchphrase => "Glow, Glow, Glow",
                altitude => 40.0,
            )],
        )?;
        create_segment(
            &index,
            vec![doc!(
                id => 3_u64,
                catchphrase => "No, No, No",
                altitude => 0.0,
            )],
        )?;
        Ok(index)
    }

    // NOTE: You cannot determine the SegmentIds that will be generated for Segments
    // ahead of time, so DocAddresses must be mapped back to a unique id for each Searcher.
    fn id_mapping(searcher: &Searcher) -> HashMap<DocAddress, u64> {
        searcher
            .search(&AllQuery, &DocSetCollector)
            .unwrap()
            .into_iter()
            .map(|doc_address| {
                let column = searcher.segment_readers()[doc_address.segment_ord as usize]
                    .fast_fields()
                    .u64("id")
                    .unwrap();
                (doc_address, column.first(doc_address.doc_id).unwrap())
            })
            .collect()
    }

    #[test]
    fn test_order_by_string() -> crate::Result<()> {
        let index = make_index()?;

        #[track_caller]
        fn assert_query(
            index: &Index,
            order: Order,
            doc_range: Range<usize>,
            expected: Vec<(Option<String>, u64)>,
        ) -> crate::Result<()> {
            let searcher = index.reader()?.searcher();
            let ids = id_mapping(&searcher);

            // Try as primitive.
            let top_collector = TopDocs::for_doc_range(doc_range)
                .order_by((SortByString::for_field("city"), order));
            let actual = searcher
                .search(&AllQuery, &top_collector)?
                .into_iter()
                .map(|(sort_key_opt, doc)| (sort_key_opt, ids[&doc]))
                .collect::<Vec<_>>();
            assert_eq!(actual, expected);
            Ok(())
        }

        assert_query(
            &index,
            Order::Asc,
            0..4,
            vec![
                (Some("austin".to_owned()), 0),
                (Some("greenville".to_owned()), 1),
                (Some("tokyo".to_owned()), 2),
                (None, 3),
            ],
        )?;

        assert_query(
            &index,
            Order::Asc,
            0..3,
            vec![
                (Some("austin".to_owned()), 0),
                (Some("greenville".to_owned()), 1),
                (Some("tokyo".to_owned()), 2),
            ],
        )?;

        assert_query(
            &index,
            Order::Asc,
            0..2,
            vec![
                (Some("austin".to_owned()), 0),
                (Some("greenville".to_owned()), 1),
            ],
        )?;

        assert_query(
            &index,
            Order::Asc,
            0..1,
            vec![(Some("austin".to_string()), 0)],
        )?;

        assert_query(
            &index,
            Order::Asc,
            1..3,
            vec![
                (Some("greenville".to_owned()), 1),
                (Some("tokyo".to_owned()), 2),
            ],
        )?;

        assert_query(
            &index,
            Order::Desc,
            0..4,
            vec![
                (Some("tokyo".to_owned()), 2),
                (Some("greenville".to_owned()), 1),
                (Some("austin".to_owned()), 0),
                (None, 3),
            ],
        )?;

        assert_query(
            &index,
            Order::Desc,
            1..3,
            vec![
                (Some("greenville".to_owned()), 1),
                (Some("austin".to_owned()), 0),
            ],
        )?;

        assert_query(
            &index,
            Order::Desc,
            0..1,
            vec![(Some("tokyo".to_owned()), 2)],
        )?;

        Ok(())
    }

    #[test]
    fn test_order_by_f64() -> crate::Result<()> {
        let index = make_index()?;

        fn assert_query(
            index: &Index,
            order: Order,
            expected: Vec<(Option<f64>, u64)>,
        ) -> crate::Result<()> {
            let searcher = index.reader()?.searcher();
            let ids = id_mapping(&searcher);

            // Try as primitive.
            let top_collector = TopDocs::with_limit(3)
                .order_by((SortByStaticFastValue::<f64>::for_field("altitude"), order));
            let actual = searcher
                .search(&AllQuery, &top_collector)?
                .into_iter()
                .map(|(altitude_opt, doc)| (altitude_opt, ids[&doc]))
                .collect::<Vec<_>>();
            assert_eq!(actual, expected);

            Ok(())
        }

        assert_query(
            &index,
            Order::Asc,
            vec![(Some(0.0), 3), (Some(27.0), 1), (Some(40.0), 2)],
        )?;

        assert_query(
            &index,
            Order::Desc,
            vec![(Some(149.0), 0), (Some(40.0), 2), (Some(27.0), 1)],
        )?;

        Ok(())
    }

    #[test]
    fn test_order_by_score() -> crate::Result<()> {
        let index = make_index()?;

        fn query(index: &Index, order: Order) -> crate::Result<Vec<(Score, u64)>> {
            let searcher = index.reader()?.searcher();
            let ids = id_mapping(&searcher);

            let top_collector = TopDocs::with_limit(4).order_by((SortBySimilarityScore, order));
            let field = index.schema().get_field("catchphrase").unwrap();
            let query_parser = QueryParser::for_index(index, vec![field]);
            let text_query = query_parser.parse_query("glow")?;

            Ok(searcher
                .search(&text_query, &top_collector)?
                .into_iter()
                .map(|(score, doc)| (score, ids[&doc]))
                .collect())
        }

        assert_eq!(
            &query(&index, Order::Desc)?,
            &[(0.5604893, 2), (0.4904281, 1), (0.35667497, 0),]
        );

        assert_eq!(
            &query(&index, Order::Asc)?,
            &[(0.35667497, 0), (0.4904281, 1), (0.5604893, 2),]
        );

        Ok(())
    }

    #[test]
    fn test_order_by_score_then_string() -> crate::Result<()> {
        let index = make_index()?;

        type SortKey = (Score, Option<String>);

        fn query(
            index: &Index,
            score_order: Order,
            city_order: Order,
        ) -> crate::Result<Vec<(SortKey, u64)>> {
            let searcher = index.reader()?.searcher();
            let ids = id_mapping(&searcher);

            let top_collector = TopDocs::with_limit(4).order_by((
                (SortBySimilarityScore, score_order),
                (SortByString::for_field("city"), city_order),
            ));
            Ok(searcher
                .search(&AllQuery, &top_collector)?
                .into_iter()
                .map(|(f, doc)| (f, ids[&doc]))
                .collect())
        }

        assert_eq!(
            &query(&index, Order::Asc, Order::Asc)?,
            &[
                ((1.0, Some("austin".to_owned())), 0),
                ((1.0, Some("greenville".to_owned())), 1),
                ((1.0, Some("tokyo".to_owned())), 2),
                ((1.0, None), 3),
            ]
        );

        assert_eq!(
            &query(&index, Order::Asc, Order::Desc)?,
            &[
                ((1.0, Some("tokyo".to_owned())), 2),
                ((1.0, Some("greenville".to_owned())), 1),
                ((1.0, Some("austin".to_owned())), 0),
                ((1.0, None), 3),
            ]
        );
        Ok(())
    }

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn test_order_by_string_prop(
            order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)),
            limit in 1..64_usize,
            offset in 0..64_usize,
            segments_terms in
                proptest::collection::vec(
                    proptest::collection::vec(0..32_u8, 1..32_usize),
                    0..8_usize,
                )
        ) {
            let mut schema_builder = Schema::builder();
            let city = schema_builder.add_text_field("city", TEXT | FAST);
            let schema = schema_builder.build();
            let index = Index::create_in_ram(schema);
            let mut index_writer = index.writer_for_tests()?;

            // A Vec<Vec<u8>>, where the outer Vec represents segments, and the inner Vec
            // represents terms.
            for segment_terms in segments_terms.into_iter() {
                for term in segment_terms.into_iter() {
                    let term = format!("{term:0>3}");
                    index_writer.add_document(doc!(
                        city => term,
                    ))?;
                }
                index_writer.commit()?;
            }

            let searcher = index.reader()?.searcher();
            let top_n_results = searcher.search(&AllQuery, &TopDocs::with_limit(limit)
                .and_offset(offset)
                .order_by_string_fast_field("city", order))?;
            let all_results = searcher.search(&AllQuery, &DocSetCollector)?.into_iter().map(|doc_address| {
                // Get the term for this address.
                let column = searcher.segment_readers()[doc_address.segment_ord as usize].fast_fields().str("city").unwrap().unwrap();
                let value = column.term_ords(doc_address.doc_id).next().map(|term_ord| {
                    let mut city = Vec::new();
                    column.dictionary().ord_to_term(term_ord, &mut city).unwrap();
                    String::try_from(city).unwrap()
                });
                (value, doc_address)
            });

            // Using the TopDocs collector should always be equivalent to sorting, skipping the
            // offset, and then taking the limit.
            let sorted_docs: Vec<_> = if order.is_desc() {
                let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
                comparable_docs.sort();
                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
            } else {
                let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
                    all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
                comparable_docs.sort();
                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
            };
            let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();
            prop_assert_eq!(
                expected_docs,
                top_n_results
            );
        }
    }
}
348
src/collector/sort_key/order.rs
Normal file
@@ -0,0 +1,348 @@
use std::cmp::Ordering;

use serde::{Deserialize, Serialize};

use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema;
use crate::{DocId, Order, Score};

/// Comparator trait defining the order in which documents should be sorted.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
    /// Return the order between two values.
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}

/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;

impl<T: PartialOrd> Comparator<T> for NaturalComparator {
    #[inline(always)]
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
        lhs.partial_cmp(rhs).unwrap()
    }
}

/// Sorts documents in reverse order.
///
/// If the sort key is None, it will be considered the lowest value, and will therefore appear
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both versions will retain the higher doc ids.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;

impl<T> Comparator<T> for ReverseComparator
where NaturalComparator: Comparator<T>
{
    #[inline(always)]
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
        NaturalComparator.compare(rhs, lhs)
    }
}

/// Sorts documents in reverse order, but considers None as having the lowest value.
///
/// This is usually what is wanted when sorting by a field in ascending order.
/// For instance, in an e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;

impl<T> Comparator<Option<T>> for ReverseNoneIsLowerComparator
where ReverseComparator: Comparator<T>
{
    #[inline(always)]
    fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
        match (lhs_opt, rhs_opt) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
            (Some(lhs), Some(rhs)) => ReverseComparator.compare(lhs, rhs),
        }
    }
}

impl Comparator<u32> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

impl Comparator<u64> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

impl Comparator<f64> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

impl Comparator<f32> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

impl Comparator<i64> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

impl Comparator<String> for ReverseNoneIsLowerComparator {
    #[inline(always)]
    fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
        ReverseComparator.compare(lhs, rhs)
    }
}

/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
    /// Natural order (see [NaturalComparator])
    #[default]
    Natural,
    /// Reverse order (see [ReverseComparator])
    Reverse,
    /// Reverse order, treating None as the lowest value (see [ReverseNoneIsLowerComparator])
    ReverseNoneLower,
}

impl From<Order> for ComparatorEnum {
    fn from(order: Order) -> Self {
        match order {
            Order::Asc => ComparatorEnum::ReverseNoneLower,
            Order::Desc => ComparatorEnum::Natural,
        }
    }
}
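
// Editorial sketch (self-contained, illustrative; not from this diff): a top-k
// selector keeps the "largest" elements under its comparator, so ascending
// output uses a reversed comparator, and treating None as the lowest value
// makes missing values lose every comparison and land last.
fn reverse_none_lower_demo(lhs: &Option<u64>, rhs: &Option<u64>) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (lhs, rhs) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Less,
        (Some(_), None) => Ordering::Greater,
        (Some(l), Some(r)) => r.cmp(l), // reversed
    }
}

fn demo_asc_via_reverse_comparator() {
    let mut vals = vec![Some(3), None, Some(1), Some(2)];
    // Sorting "descending" by this comparator yields ascending values, None last.
    vals.sort_by(|a, b| reverse_none_lower_demo(b, a));
    assert_eq!(vals, vec![Some(1), Some(2), Some(3), None]);
}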

impl<T> Comparator<T> for ComparatorEnum
where
    ReverseNoneIsLowerComparator: Comparator<T>,
    NaturalComparator: Comparator<T>,
    ReverseComparator: Comparator<T>,
{
    #[inline(always)]
    fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
        match self {
            ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
            ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
            ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
        }
    }
}

impl<Head, Tail, LeftComparator, RightComparator> Comparator<(Head, Tail)>
    for (LeftComparator, RightComparator)
where
    LeftComparator: Comparator<Head>,
    RightComparator: Comparator<Tail>,
{
    #[inline(always)]
    fn compare(&self, lhs: &(Head, Tail), rhs: &(Head, Tail)) -> Ordering {
        self.0
            .compare(&lhs.0, &rhs.0)
            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
    }
}
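
Pairs of comparators compose lexicographically: the head comparator decides, unless it reports a tie, in which case the tail comparator breaks it. For example (toy values, assuming the comparator impls above are in scope):

let cmp = (NaturalComparator, ReverseComparator);
// Heads tie (1 == 1), so the reversed tail decides: "a" ranks above "b".
assert_eq!(
    cmp.compare(&(1u64, "a".to_string()), &(1u64, "b".to_string())),
    Ordering::Greater
);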

impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, (Type2, Type3))>
    for (Comparator1, Comparator2, Comparator3)
where
    Comparator1: Comparator<Type1>,
    Comparator2: Comparator<Type2>,
    Comparator3: Comparator<Type3>,
{
    #[inline(always)]
    fn compare(&self, lhs: &(Type1, (Type2, Type3)), rhs: &(Type1, (Type2, Type3))) -> Ordering {
        self.0
            .compare(&lhs.0, &rhs.0)
            .then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
            .then_with(|| self.2.compare(&lhs.1 .1, &rhs.1 .1))
    }
}

impl<Type1, Type2, Type3, Comparator1, Comparator2, Comparator3> Comparator<(Type1, Type2, Type3)>
    for (Comparator1, Comparator2, Comparator3)
where
    Comparator1: Comparator<Type1>,
    Comparator2: Comparator<Type2>,
    Comparator3: Comparator<Type3>,
{
    #[inline(always)]
    fn compare(&self, lhs: &(Type1, Type2, Type3), rhs: &(Type1, Type2, Type3)) -> Ordering {
        self.0
            .compare(&lhs.0, &rhs.0)
            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
            .then_with(|| self.2.compare(&lhs.2, &rhs.2))
    }
}

impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
    Comparator<(Type1, (Type2, (Type3, Type4)))>
    for (Comparator1, Comparator2, Comparator3, Comparator4)
where
    Comparator1: Comparator<Type1>,
    Comparator2: Comparator<Type2>,
    Comparator3: Comparator<Type3>,
    Comparator4: Comparator<Type4>,
{
    #[inline(always)]
    fn compare(
        &self,
        lhs: &(Type1, (Type2, (Type3, Type4))),
        rhs: &(Type1, (Type2, (Type3, Type4))),
    ) -> Ordering {
        self.0
            .compare(&lhs.0, &rhs.0)
            .then_with(|| self.1.compare(&lhs.1 .0, &rhs.1 .0))
            .then_with(|| self.2.compare(&lhs.1 .1 .0, &rhs.1 .1 .0))
            .then_with(|| self.3.compare(&lhs.1 .1 .1, &rhs.1 .1 .1))
    }
}

impl<Type1, Type2, Type3, Type4, Comparator1, Comparator2, Comparator3, Comparator4>
    Comparator<(Type1, Type2, Type3, Type4)>
    for (Comparator1, Comparator2, Comparator3, Comparator4)
where
    Comparator1: Comparator<Type1>,
    Comparator2: Comparator<Type2>,
    Comparator3: Comparator<Type3>,
    Comparator4: Comparator<Type4>,
{
    #[inline(always)]
    fn compare(
        &self,
        lhs: &(Type1, Type2, Type3, Type4),
        rhs: &(Type1, Type2, Type3, Type4),
    ) -> Ordering {
        self.0
            .compare(&lhs.0, &rhs.0)
            .then_with(|| self.1.compare(&lhs.1, &rhs.1))
            .then_with(|| self.2.compare(&lhs.2, &rhs.2))
            .then_with(|| self.3.compare(&lhs.3, &rhs.3))
    }
}

impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, ComparatorEnum)
where
    TSortKeyComputer: SortKeyComputer,
    ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
    ComparatorEnum: Comparator<
        <<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
    >,
{
    type SortKey = TSortKeyComputer::SortKey;

    type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;

    type Comparator = ComparatorEnum;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)
    }

    fn requires_scoring(&self) -> bool {
        self.0.requires_scoring()
    }

    fn comparator(&self) -> Self::Comparator {
        self.1
    }

    fn segment_sort_key_computer(
        &self,
        segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        let child = self.0.segment_sort_key_computer(segment_reader)?;
        Ok(SegmentSortKeyComputerWithComparator {
            segment_sort_key_computer: child,
            comparator: self.comparator(),
        })
    }
}

impl<TSortKeyComputer> SortKeyComputer for (TSortKeyComputer, Order)
where
    TSortKeyComputer: SortKeyComputer,
    ComparatorEnum: Comparator<TSortKeyComputer::SortKey>,
    ComparatorEnum: Comparator<
        <<TSortKeyComputer as SortKeyComputer>::Child as SegmentSortKeyComputer>::SegmentSortKey,
    >,
{
    type SortKey = TSortKeyComputer::SortKey;

    type Child = SegmentSortKeyComputerWithComparator<TSortKeyComputer::Child, Self::Comparator>;

    type Comparator = ComparatorEnum;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)
    }

    fn requires_scoring(&self) -> bool {
        self.0.requires_scoring()
    }

    fn comparator(&self) -> Self::Comparator {
        self.1.into()
    }

    fn segment_sort_key_computer(
        &self,
        segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        let child = self.0.segment_sort_key_computer(segment_reader)?;
        Ok(SegmentSortKeyComputerWithComparator {
            segment_sort_key_computer: child,
            comparator: self.comparator(),
        })
    }
}
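
Together, these two impls let a caller pair any sort key computer with either an explicit ComparatorEnum or a plain Order. A hypothetical usage sketch (the "price" field is invented for illustration; SortByStaticFastValue is introduced later in this diff):

let by_price_asc = (SortByStaticFastValue::<u64>::for_field("price"), Order::Asc);
// `by_price_asc` is itself a SortKeyComputer; its comparator is
// ComparatorEnum::ReverseNoneLower: cheapest first, docs without a price last.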

/// A segment sort key computer with a custom ordering.
pub struct SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> {
    segment_sort_key_computer: TSegmentSortKeyComputer,
    comparator: TComparator,
}

impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComputer
    for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where
    TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
    TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send,
    TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{
    type SortKey = TSegmentSortKeyComputer::SortKey;
    type SegmentSortKey = TSegmentSortKey;

    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
        self.segment_sort_key_computer.segment_sort_key(doc, score)
    }

    #[inline(always)]
    fn compare_segment_sort_key(
        &self,
        left: &Self::SegmentSortKey,
        right: &Self::SegmentSortKey,
    ) -> Ordering {
        self.comparator.compare(left, right)
    }

    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
        self.segment_sort_key_computer
            .convert_segment_sort_key(sort_key)
    }
}

src/collector/sort_key/sort_by_score.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};

/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
pub struct SortBySimilarityScore;

impl SortKeyComputer for SortBySimilarityScore {
    type SortKey = Score;

    type Child = SortBySimilarityScore;

    type Comparator = NaturalComparator;

    fn requires_scoring(&self) -> bool {
        true
    }

    fn segment_sort_key_computer(
        &self,
        _segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        Ok(SortBySimilarityScore)
    }

    // Sorting by score is special in that it allows for the Block-WAND optimization.
    fn collect_segment_top_k(
        &self,
        k: usize,
        weight: &dyn crate::query::Weight,
        reader: &crate::SegmentReader,
        segment_ord: u32,
    ) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
        let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
            TopNComputer::new_with_comparator(k, self.comparator());

        if let Some(alive_bitset) = reader.alive_bitset() {
            let mut threshold = Score::MIN;
            top_n.threshold = Some(threshold);
            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
                if alive_bitset.is_deleted(doc) {
                    return threshold;
                }
                top_n.push(score, doc);
                threshold = top_n.threshold.unwrap_or(Score::MIN);
                threshold
            })?;
        } else {
            weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
                top_n.push(score, doc);
                top_n.threshold.unwrap_or(Score::MIN)
            })?;
        }

        Ok(top_n
            .into_vec()
            .into_iter()
            .map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
            .collect())
    }
}

impl SegmentSortKeyComputer for SortBySimilarityScore {
    type SortKey = Score;

    type SegmentSortKey = Score;

    #[inline(always)]
    fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {
        score
    }

    fn convert_segment_sort_key(&self, score: Score) -> Score {
        score
    }
}
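
A usage sketch tying this to the collector API touched elsewhere in this diff (searcher and query assumed to exist): sorting by similarity score is the default top-K path, and the specialization above lets it prune with Block-WAND instead of scoring every document.

let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;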

src/collector/sort_key/sort_by_static_fast_value.rs (new file, 98 lines)
@@ -0,0 +1,98 @@
use std::marker::PhantomData;

use columnar::Column;

use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::{DocId, Score, SegmentReader};

/// Sorts by a fast value (u64, i64, f64, bool).
///
/// The field must appear explicitly in the schema, with the right type, and be declared as
/// a fast field.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByStaticFastValue<T: FastValue> {
    field: String,
    typ: PhantomData<T>,
}

impl<T: FastValue> SortByStaticFastValue<T> {
    /// Creates a new `SortByStaticFastValue` instance for the given field.
    pub fn for_field(column_name: impl ToString) -> SortByStaticFastValue<T> {
        Self {
            field: column_name.to_string(),
            typ: PhantomData,
        }
    }
}

impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
    type Child = SortByFastValueSegmentSortKeyComputer<T>;

    type SortKey = Option<T>;

    type Comparator = NaturalComparator;

    fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
        // At the segment sort key computer level, we rely on the u64 representation.
        // The mapping is monotonic, so it is sufficient to compute our top-K docs.
        let field = schema.get_field(&self.field)?;
        let field_entry = schema.get_field_entry(field);
        if !field_entry.is_fast() {
            return Err(crate::TantivyError::SchemaError(format!(
                "Field `{}` is not a fast field.",
                self.field,
            )));
        }
        let schema_type = field_entry.field_type().value_type();
        if schema_type != T::to_type() {
            return Err(crate::TantivyError::SchemaError(format!(
                "Field `{}` is of type {schema_type:?}, not of the type {:?}.",
                &self.field,
                T::to_type()
            )));
        }
        Ok(())
    }

    fn segment_sort_key_computer(
        &self,
        segment_reader: &SegmentReader,
    ) -> crate::Result<Self::Child> {
        let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
        let (sort_column, _sort_column_type) =
            sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
                field_name: self.field.clone(),
            })?;
        Ok(SortByFastValueSegmentSortKeyComputer {
            sort_column,
            typ: PhantomData,
        })
    }
}

pub struct SortByFastValueSegmentSortKeyComputer<T> {
    sort_column: Column<u64>,
    typ: PhantomData<T>,
}

impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
    type SortKey = Option<T>;

    type SegmentSortKey = Option<u64>;

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {
        self.sort_column.first(doc)
    }

    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
        sort_key.map(T::from_u64)
    }
}
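
A minimal usage sketch (schema and field name invented for illustration): check_schema fails fast at the collector level, before any segment is visited.

let by_price = SortByStaticFastValue::<u64>::for_field("price");
// Errors out if "price" is missing, not declared as fast, or not a u64 field.
by_price.check_schema(&schema)?;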

src/collector/sort_key/sort_by_string.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
use columnar::StrColumn;

use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};

/// Sort by the first value of a string column.
///
/// The string can be dynamic (coming from a json field)
/// or static (specifically defined in the configuration).
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByString {
    column_name: String,
}

impl SortByString {
    /// Creates a new sort-by-string sort key computer.
    pub fn for_field(column_name: impl ToString) -> Self {
        SortByString {
            column_name: column_name.to_string(),
        }
    }
}

impl SortKeyComputer for SortByString {
    type SortKey = Option<String>;

    type Child = ByStringColumnSegmentSortKeyComputer;

    type Comparator = NaturalComparator;

    fn segment_sort_key_computer(
        &self,
        segment_reader: &crate::SegmentReader,
    ) -> crate::Result<Self::Child> {
        let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
        Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })
    }
}

pub struct ByStringColumnSegmentSortKeyComputer {
    str_column_opt: Option<StrColumn>,
}

impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
    type SortKey = Option<String>;

    type SegmentSortKey = Option<TermOrdinal>;

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
        let str_column = self.str_column_opt.as_ref()?;
        str_column.ords().first(doc)
    }

    fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
        let term_ord = term_ord_opt?;
        let str_column = self.str_column_opt.as_ref()?;
        let mut bytes = Vec::new();
        str_column
            .dictionary()
            .ord_to_term(term_ord, &mut bytes)
            .ok()?;
        String::try_from(bytes).ok()
    }
}
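
Worth spelling out: within a segment, the sort key is the term ordinal, not the string. Ordinals are assigned in lexicographic order of the terms, so comparing ordinals is equivalent to comparing the strings, and the bytes are only materialized once per surviving hit in convert_segment_sort_key. The invariant relied upon, as a comment:

// For any two docs a, b in the same segment:
//   ord(a) < ord(b)  <=>  term(a) < term(b)   (lexicographic byte order)
// so the top-K can be computed on Option<TermOrdinal>, and the String lookup
// is deferred to the (at most K) surviving documents.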

src/collector/sort_key/sort_key_computer.rs (new file, 631 lines)
@@ -0,0 +1,631 @@
use std::cmp::Ordering;

use crate::collector::sort_key::{Comparator, NaturalComparator};
use crate::collector::sort_key_top_collector::TopBySortKeySegmentCollector;
use crate::collector::{default_collect_segment_impl, SegmentCollector as _, TopNComputer};
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};

/// A `SegmentSortKeyComputer` computes the sort key
/// for a given document belonging to a specific segment.
///
/// It is the segment-local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static {
    /// The final sort key being emitted.
    type SortKey: 'static + PartialOrd + Send + Sync + Clone;

    /// Sort key used at the segment level by the `SegmentSortKeyComputer`.
    ///
    /// It is typically small, like a `u64`, and is meant to be converted
    /// to the final sort key at the end of the collection of the segment.
    type SegmentSortKey: 'static + PartialOrd + Send + Sync + Clone;

    /// Computes the sort key for the given document and score.
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;

    /// Computes the sort key and pushes the document into a TopN computer.
    ///
    /// When using a tuple as the sorting key, the sort key is evaluated in a lazy manner.
    #[inline(always)]
    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
        &mut self,
        doc: DocId,
        score: Score,
        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
    ) {
        let sort_key = self.segment_sort_key(doc, score);
        top_n_computer.push(sort_key, doc);
    }

    /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
    /// its ordering.
    ///
    /// This method must be consistent with the `SortKey` ordering.
    #[inline(always)]
    fn compare_segment_sort_key(
        &self,
        left: &Self::SegmentSortKey,
        right: &Self::SegmentSortKey,
    ) -> Ordering {
        NaturalComparator.compare(left, right)
    }

    /// Implementing this method makes it possible to avoid computing
    /// a sort key entirely if we can assess that it won't pass a threshold
    /// with a partial computation.
    ///
    /// This is currently used for lexicographic sorting.
    fn accept_sort_key_lazy(
        &mut self,
        doc_id: DocId,
        score: Score,
        threshold: &Self::SegmentSortKey,
    ) -> Option<(Ordering, Self::SegmentSortKey)> {
        let sort_key = self.segment_sort_key(doc_id, score);
        let cmp = self.compare_segment_sort_key(&sort_key, threshold);
        if cmp == Ordering::Less {
            None
        } else {
            Some((cmp, sort_key))
        }
    }

    /// Converts a segment-level sort key into the global sort key.
    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey;
}
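
To make the lazy-acceptance contract concrete, here is the default logic above in comment form (a paraphrase, not new behavior):

// accept_sort_key_lazy, default implementation:
//   key = segment_sort_key(doc, score)
//   cmp = compare_segment_sort_key(key, threshold)
//   cmp == Less -> None               (the doc cannot enter the top-K)
//   otherwise   -> Some((cmp, key))
// Tuple implementations refine this: the head component is computed first, and
// when it falls strictly below the threshold's head, the tail is never computed.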

/// `SortKeyComputer` defines the sort key to be used by a top-K collector.
///
/// The `SortKeyComputer` itself does not do much of the computation.
/// Instead, it helps construct `Self::Child` instances that will compute
/// the sort key at a segment scale.
pub trait SortKeyComputer: Sync {
    /// The sort key type.
    type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug;
    /// Type of the associated [`SegmentSortKeyComputer`].
    type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
    /// Comparator type.
    type Comparator: Comparator<Self::SortKey>
        + Comparator<<Self::Child as SegmentSortKeyComputer>::SegmentSortKey>
        + 'static;

    /// Checks whether the schema is compatible with the sort key computer.
    fn check_schema(&self, _schema: &Schema) -> crate::Result<()> {
        Ok(())
    }

    /// Returns the sort key comparator.
    fn comparator(&self) -> Self::Comparator {
        Self::Comparator::default()
    }

    /// Indicates whether the sort key actually uses the similarity score (by default BM25).
    /// If set to false, the similarity score might not be computed (as an optimization),
    /// and the score fed into the segment sort key computer could take any value.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Sorting by score has an overriding implementation for BM25 scores, using Block-WAND.
    fn collect_segment_top_k(
        &self,
        k: usize,
        weight: &dyn crate::query::Weight,
        reader: &crate::SegmentReader,
        segment_ord: u32,
    ) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
        let with_scoring = self.requires_scoring();
        let segment_sort_key_computer = self.segment_sort_key_computer(reader)?;
        let topn_computer = TopNComputer::new_with_comparator(k, self.comparator());
        let mut segment_top_key_collector = TopBySortKeySegmentCollector {
            topn_computer,
            segment_ord,
            segment_sort_key_computer,
        };
        default_collect_segment_impl(&mut segment_top_key_collector, weight, reader, with_scoring)?;
        Ok(segment_top_key_collector.harvest())
    }

    /// Builds a child sort key computer for a specific segment.
    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}

impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
    for (HeadSortKeyComputer, TailSortKeyComputer)
where
    HeadSortKeyComputer: SortKeyComputer,
    TailSortKeyComputer: SortKeyComputer,
{
    type SortKey = (
        <HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
        <TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
    );
    type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);

    type Comparator = (
        HeadSortKeyComputer::Comparator,
        TailSortKeyComputer::Comparator,
    );

    fn comparator(&self) -> Self::Comparator {
        (self.0.comparator(), self.1.comparator())
    }

    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        Ok((
            self.0.segment_sort_key_computer(segment_reader)?,
            self.1.segment_sort_key_computer(segment_reader)?,
        ))
    }

    /// Checks whether the schema is compatible with the sort key computer.
    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        Ok(())
    }

    /// Indicates whether the sort key actually uses the similarity score (by default BM25).
    /// If set to false, the similarity score might not be computed (as an optimization),
    /// and the score fed into the segment sort key computer could take any value.
    fn requires_scoring(&self) -> bool {
        self.0.requires_scoring() || self.1.requires_scoring()
    }
}

impl<HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer> SegmentSortKeyComputer
    for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer)
where
    HeadSegmentSortKeyComputer: SegmentSortKeyComputer,
    TailSegmentSortKeyComputer: SegmentSortKeyComputer,
{
    type SortKey = (
        HeadSegmentSortKeyComputer::SortKey,
        TailSegmentSortKeyComputer::SortKey,
    );
    type SegmentSortKey = (
        HeadSegmentSortKeyComputer::SegmentSortKey,
        TailSegmentSortKeyComputer::SegmentSortKey,
    );

    /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
    /// its ordering.
    ///
    /// By default, it uses the natural ordering.
    #[inline]
    fn compare_segment_sort_key(
        &self,
        left: &Self::SegmentSortKey,
        right: &Self::SegmentSortKey,
    ) -> Ordering {
        self.0
            .compare_segment_sort_key(&left.0, &right.0)
            .then_with(|| self.1.compare_segment_sort_key(&left.1, &right.1))
    }

    #[inline(always)]
    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
        &mut self,
        doc: DocId,
        score: Score,
        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
    ) {
        let sort_key: Self::SegmentSortKey;
        if let Some(threshold) = &top_n_computer.threshold {
            if let Some((_cmp, lazy_sort_key)) = self.accept_sort_key_lazy(doc, score, threshold) {
                sort_key = lazy_sort_key;
            } else {
                return;
            }
        } else {
            sort_key = self.segment_sort_key(doc, score);
        }
        top_n_computer.append_doc(doc, sort_key);
    }

    #[inline(always)]
    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
        let head_sort_key = self.0.segment_sort_key(doc, score);
        let tail_sort_key = self.1.segment_sort_key(doc, score);
        (head_sort_key, tail_sort_key)
    }

    fn accept_sort_key_lazy(
        &mut self,
        doc_id: DocId,
        score: Score,
        threshold: &Self::SegmentSortKey,
    ) -> Option<(Ordering, Self::SegmentSortKey)> {
        let (head_threshold, tail_threshold) = threshold;
        let (head_cmp, head_sort_key) =
            self.0.accept_sort_key_lazy(doc_id, score, head_threshold)?;
        if head_cmp == Ordering::Equal {
            let (tail_cmp, tail_sort_key) =
                self.1.accept_sort_key_lazy(doc_id, score, tail_threshold)?;
            Some((tail_cmp, (head_sort_key, tail_sort_key)))
        } else {
            let tail_sort_key = self.1.segment_sort_key(doc_id, score);
            Some((head_cmp, (head_sort_key, tail_sort_key)))
        }
    }

    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
        let (head_sort_key, tail_sort_key) = sort_key;
        (
            self.0.convert_segment_sort_key(head_sort_key),
            self.1.convert_segment_sort_key(tail_sort_key),
        )
    }
}
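
The lazy path above is easiest to read against the test_lazy_score_computer test below, where the candidate's key is always (200, "b") and only the threshold varies:

// threshold (300, _)   -> head 200 < 300: rejected, tail closure never called
// threshold (200, "c") -> head ties, tail computed: "b" < "c", rejected
// threshold (100, _)   -> head 200 > 100: accepted, tail computed to build the key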

/// This struct is used as an adapter: it takes a sort key computer and maps its
/// sort key to a new sort key.
pub struct MappedSegmentSortKeyComputer<T, PreviousSortKey, NewSortKey> {
    sort_key_computer: T,
    map: fn(PreviousSortKey) -> NewSortKey,
}

impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
    for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where
    T: SegmentSortKeyComputer<SortKey = PreviousScore>,
    PreviousScore: 'static + Clone + Send + Sync + PartialOrd,
    NewScore: 'static + Clone + Send + Sync + PartialOrd,
{
    type SortKey = NewScore;
    type SegmentSortKey = T::SegmentSortKey;

    fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
        self.sort_key_computer.segment_sort_key(doc, score)
    }

    fn accept_sort_key_lazy(
        &mut self,
        doc_id: DocId,
        score: Score,
        threshold: &Self::SegmentSortKey,
    ) -> Option<(Ordering, Self::SegmentSortKey)> {
        self.sort_key_computer
            .accept_sort_key_lazy(doc_id, score, threshold)
    }

    #[inline(always)]
    fn compute_sort_key_and_collect<C: Comparator<Self::SegmentSortKey>>(
        &mut self,
        doc: DocId,
        score: Score,
        top_n_computer: &mut TopNComputer<Self::SegmentSortKey, DocId, C>,
    ) {
        self.sort_key_computer
            .compute_sort_key_and_collect(doc, score, top_n_computer);
    }

    fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey {
        (self.map)(
            self.sort_key_computer
                .convert_segment_sort_key(segment_sort_key),
        )
    }
}

// We then re-use our (head, tail) implementation and our mapper by viewing any tuple
// (a, b, c, ...) as the chain (a, (b, (c, ...))).

impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3> SortKeyComputer
    for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3)
where
    SortKeyComputer1: SortKeyComputer,
    SortKeyComputer2: SortKeyComputer,
    SortKeyComputer3: SortKeyComputer,
{
    type SortKey = (
        SortKeyComputer1::SortKey,
        SortKeyComputer2::SortKey,
        SortKeyComputer3::SortKey,
    );
    type Child = MappedSegmentSortKeyComputer<
        <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child,
        (
            SortKeyComputer1::SortKey,
            (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey),
        ),
        Self::SortKey,
    >;

    type Comparator = (
        SortKeyComputer1::Comparator,
        SortKeyComputer2::Comparator,
        SortKeyComputer3::Comparator,
    );

    fn comparator(&self) -> Self::Comparator {
        (
            self.0.comparator(),
            self.1.comparator(),
            self.2.comparator(),
        )
    }

    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
        let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
        let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
        let map = |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3);
        Ok(MappedSegmentSortKeyComputer {
            sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)),
            map,
        })
    }

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        self.2.check_schema(schema)?;
        Ok(())
    }

    fn requires_scoring(&self) -> bool {
        self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
    }
}

impl<SortKeyComputer1, SortKeyComputer2, SortKeyComputer3, SortKeyComputer4> SortKeyComputer
    for (
        SortKeyComputer1,
        SortKeyComputer2,
        SortKeyComputer3,
        SortKeyComputer4,
    )
where
    SortKeyComputer1: SortKeyComputer,
    SortKeyComputer2: SortKeyComputer,
    SortKeyComputer3: SortKeyComputer,
    SortKeyComputer4: SortKeyComputer,
{
    type Child = MappedSegmentSortKeyComputer<
        <(
            SortKeyComputer1,
            (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)),
        ) as SortKeyComputer>::Child,
        (
            SortKeyComputer1::SortKey,
            (
                SortKeyComputer2::SortKey,
                (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey),
            ),
        ),
        Self::SortKey,
    >;
    type SortKey = (
        SortKeyComputer1::SortKey,
        SortKeyComputer2::SortKey,
        SortKeyComputer3::SortKey,
        SortKeyComputer4::SortKey,
    );
    type Comparator = (
        SortKeyComputer1::Comparator,
        SortKeyComputer2::Comparator,
        SortKeyComputer3::Comparator,
        SortKeyComputer4::Comparator,
    );

    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
        let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
        let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
        let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?;
        Ok(MappedSegmentSortKeyComputer {
            sort_key_computer: (
                sort_key_computer1,
                (sort_key_computer2, (sort_key_computer3, sort_key_computer4)),
            ),
            map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| {
                (sort_key1, sort_key2, sort_key3, sort_key4)
            },
        })
    }

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.0.check_schema(schema)?;
        self.1.check_schema(schema)?;
        self.2.check_schema(schema)?;
        self.3.check_schema(schema)?;
        Ok(())
    }

    fn requires_scoring(&self) -> bool {
        self.0.requires_scoring()
            || self.1.requires_scoring()
            || self.2.requires_scoring()
            || self.3.requires_scoring()
    }
}

impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
    F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
    SegmentF: 'static + FnMut(DocId) -> TSortKey,
    TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
    type SortKey = TSortKey;
    type Child = SegmentF;
    type Comparator = NaturalComparator;

    fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        Ok((self)(segment_reader))
    }
}

impl<F, TSortKey> SegmentSortKeyComputer for F
where
    F: 'static + FnMut(DocId) -> TSortKey,
    TSortKey: 'static + PartialOrd + Clone + Send + Sync,
{
    type SortKey = TSortKey;
    type SegmentSortKey = TSortKey;

    fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
        (self)(doc)
    }

    /// Converts a segment-level sort key into the global sort key.
    fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey {
        sort_key
    }
}
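
These blanket impls are what the tests below lean on: a closure over a SegmentReader that returns a per-document closure is a complete SortKeyComputer, with the NaturalComparator and no schema check. A toy sketch:

// Sort key = the doc id itself (illustration only).
let by_doc_id = |_segment_reader: &SegmentReader| |doc: DocId| doc;
// `by_doc_id` now implements SortKeyComputer with SortKey = DocId.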

#[cfg(test)]
mod tests {
    use std::cmp::Ordering;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
    use std::sync::Arc;

    use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
    use crate::schema::Schema;
    use crate::{DocId, Index, Order, SegmentReader};

    fn build_test_index() -> Index {
        let schema = Schema::builder().build();
        let index = Index::create_in_ram(schema);
        let mut index_writer = index.writer_for_tests().unwrap();
        index_writer
            .add_document(crate::TantivyDocument::default())
            .unwrap();
        index_writer.commit().unwrap();
        index
    }

    #[test]
    fn test_lazy_score_computer() {
        let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let score_computer_secondary = move |_segment_reader: &SegmentReader| {
            let call_count_new_clone = call_count_clone.clone();
            move |_doc: DocId| {
                call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
                "b"
            }
        };
        let lazy_score_computer = (score_computer_primary, score_computer_secondary);
        let index = build_test_index();
        let searcher = index.reader().unwrap().searcher();
        let mut segment_sort_key_computer = lazy_score_computer
            .segment_sort_key_computer(searcher.segment_reader(0))
            .unwrap();
        let expected_sort_key = (200, "b");
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "a"));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, "c"));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "a"));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, "c"));
            assert!(sort_key_opt.is_none());
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "a"));
            assert_eq!(sort_key_opt, None);
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, "c"));
            assert_eq!(sort_key_opt, None);
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
            assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
        }
    }

    #[test]
    fn test_lazy_score_computer_dynamic_ordering() {
        let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let score_computer_secondary = move |_segment_reader: &SegmentReader| {
            let call_count_new_clone = call_count_clone.clone();
            move |_doc: DocId| {
                call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
                2u32
            }
        };
        let lazy_score_computer = (
            (score_computer_primary, Order::Desc),
            (score_computer_secondary, Order::Asc),
        );
        let index = build_test_index();
        let searcher = index.reader().unwrap().searcher();
        let mut segment_sort_key_computer = lazy_score_computer
            .segment_sort_key_computer(searcher.segment_reader(0))
            .unwrap();
        let expected_sort_key = (200, 2u32);

        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 1u32));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 1);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(100u32, 3u32));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 2);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 1u32));
            assert!(sort_key_opt.is_none());
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 3);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(200u32, 3u32));
            assert_eq!(sort_key_opt, Some((Ordering::Greater, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 1u32));
            assert_eq!(sort_key_opt, None);
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &(300u32, 3u32));
            assert_eq!(sort_key_opt, None);
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 4);
        }
        {
            let sort_key_opt =
                segment_sort_key_computer.accept_sort_key_lazy(0u32, 1f32, &expected_sort_key);
            assert_eq!(sort_key_opt, Some((Ordering::Equal, expected_sort_key)));
            assert_eq!(call_count.load(AtomicOrdering::SeqCst), 5);
        }
        assert_eq!(
            segment_sort_key_computer.convert_segment_sort_key(expected_sort_key),
            (200u32, 2u32)
        );
    }
}

src/collector/sort_key_top_collector.rs (new file, 193 lines)
@@ -0,0 +1,193 @@
use std::ops::Range;

use crate::collector::sort_key::{Comparator, SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{Collector, SegmentCollector, TopNComputer};
use crate::query::Weight;
use crate::schema::Schema;
use crate::{DocAddress, DocId, Result, Score, SegmentReader};

pub(crate) struct TopBySortKeyCollector<TSortKeyComputer> {
    sort_key_computer: TSortKeyComputer,
    doc_range: Range<usize>,
}

impl<TSortKeyComputer> TopBySortKeyCollector<TSortKeyComputer> {
    pub fn new(sort_key_computer: TSortKeyComputer, doc_range: Range<usize>) -> Self {
        TopBySortKeyCollector {
            sort_key_computer,
            doc_range,
        }
    }
}

impl<TSortKeyComputer> Collector for TopBySortKeyCollector<TSortKeyComputer>
where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
{
    type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>;

    type Child =
        TopBySortKeySegmentCollector<TSortKeyComputer::Child, TSortKeyComputer::Comparator>;

    fn check_schema(&self, schema: &Schema) -> crate::Result<()> {
        self.sort_key_computer.check_schema(schema)
    }

    fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
        let segment_sort_key_computer = self
            .sort_key_computer
            .segment_sort_key_computer(segment_reader)?;
        let topn_computer = TopNComputer::new_with_comparator(
            self.doc_range.end,
            self.sort_key_computer.comparator(),
        );
        Ok(TopBySortKeySegmentCollector {
            topn_computer,
            segment_ord,
            segment_sort_key_computer,
        })
    }

    fn requires_scoring(&self) -> bool {
        self.sort_key_computer.requires_scoring()
    }

    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
        Ok(merge_top_k(
            segment_fruits.into_iter().flatten(),
            self.doc_range.clone(),
            self.sort_key_computer.comparator(),
        ))
    }

    fn collect_segment(
        &self,
        weight: &dyn Weight,
        segment_ord: u32,
        reader: &SegmentReader,
    ) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
        let k = self.doc_range.end;
        let docs = self
            .sort_key_computer
            .collect_segment_top_k(k, weight, reader, segment_ord)?;
        Ok(docs)
    }
}

fn merge_top_k<D: Ord, TSortKey: Clone + std::fmt::Debug, C: Comparator<TSortKey>>(
    sort_key_docs: impl Iterator<Item = (TSortKey, D)>,
    doc_range: Range<usize>,
    comparator: C,
) -> Vec<(TSortKey, D)> {
    if doc_range.is_empty() {
        return Vec::new();
    }
    let mut top_collector: TopNComputer<TSortKey, D, C> =
        TopNComputer::new_with_comparator(doc_range.end, comparator);
    for (sort_key, doc) in sort_key_docs {
        top_collector.push(sort_key, doc);
    }
    top_collector
        .into_sorted_vec()
        .into_iter()
        .skip(doc_range.start)
        .map(|cdoc| (cdoc.sort_key, cdoc.doc))
        .collect()
}
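
merge_top_k re-applies the collection-time comparator to the concatenated per-segment fruits, then drops the first doc_range.start hits. Grounded in the tests below (ten hits scored 0.0 through 9.0):

// Order::Desc with doc_range 2..4: the top-4 are 9, 8, 7, 6; skipping the
// first 2 leaves [(7.0, 7), (6.0, 6)].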

pub struct TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
    TSegmentSortKeyComputer: SegmentSortKeyComputer,
    C: Comparator<TSegmentSortKeyComputer::SegmentSortKey>,
{
    pub(crate) topn_computer: TopNComputer<TSegmentSortKeyComputer::SegmentSortKey, DocId, C>,
    pub(crate) segment_ord: u32,
    pub(crate) segment_sort_key_computer: TSegmentSortKeyComputer,
}

impl<TSegmentSortKeyComputer, C> SegmentCollector
    for TopBySortKeySegmentCollector<TSegmentSortKeyComputer, C>
where
    TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer,
    C: Comparator<TSegmentSortKeyComputer::SegmentSortKey> + 'static,
{
    type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>;

    fn collect(&mut self, doc: DocId, score: Score) {
        self.segment_sort_key_computer.compute_sort_key_and_collect(
            doc,
            score,
            &mut self.topn_computer,
        );
    }

    fn harvest(self) -> Self::Fruit {
        let segment_ord = self.segment_ord;
        let segment_hits: Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)> = self
            .topn_computer
            .into_vec()
            .into_iter()
            .map(|comparable_doc| {
                let sort_key = self
                    .segment_sort_key_computer
                    .convert_segment_sort_key(comparable_doc.sort_key);
                (
                    sort_key,
                    DocAddress {
                        segment_ord,
                        doc_id: comparable_doc.doc,
                    },
                )
            })
            .collect();
        segment_hits
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Range;

    use rand;
    use rand::seq::SliceRandom as _;

    use super::merge_top_k;
    use crate::collector::sort_key::ComparatorEnum;
    use crate::Order;

    fn test_merge_top_k_aux(
        order: Order,
        doc_range: Range<usize>,
        expected: &[(crate::Score, usize)],
    ) {
        let mut vals: Vec<(crate::Score, usize)> = (0..10).map(|val| (val as f32, val)).collect();
        vals.shuffle(&mut rand::thread_rng());
        let vals_merged = merge_top_k(vals.into_iter(), doc_range, ComparatorEnum::from(order));
        assert_eq!(&vals_merged, expected);
    }

    #[test]
    fn test_merge_top_k() {
        test_merge_top_k_aux(Order::Asc, 0..0, &[]);
        test_merge_top_k_aux(Order::Asc, 3..3, &[]);
        test_merge_top_k_aux(Order::Asc, 0..3, &[(0.0f32, 0), (1.0f32, 1), (2.0f32, 2)]);
        test_merge_top_k_aux(
            Order::Asc,
            0..11,
            &[
                (0.0f32, 0),
                (1.0f32, 1),
                (2.0f32, 2),
                (3.0f32, 3),
                (4.0f32, 4),
                (5.0f32, 5),
                (6.0f32, 6),
                (7.0f32, 7),
                (8.0f32, 8),
                (9.0f32, 9),
            ],
        );
        test_merge_top_k_aux(Order::Asc, 1..3, &[(1.0f32, 1), (2.0f32, 2)]);
        test_merge_top_k_aux(Order::Desc, 0..2, &[(9.0f32, 9), (8.0f32, 8)]);
        test_merge_top_k_aux(Order::Desc, 2..4, &[(7.0f32, 7), (6.0f32, 6)]);
    }
}
@@ -40,7 +40,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
     let filter_some_collector = FilterCollector::new(
         "price".to_string(),
         &|value: u64| value > 20_120u64,
-        TopDocs::with_limit(2),
+        TopDocs::with_limit(2).order_by_score(),
     );
     let top_docs = searcher.search(&query, &filter_some_collector)?;

@@ -50,7 +50,7 @@ pub fn test_filter_collector() -> crate::Result<()> {
     let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(
         "price".to_string(),
         &|value| value < 5u64,
-        TopDocs::with_limit(2),
+        TopDocs::with_limit(2).order_by_score(),
     );
     let filtered_top_docs = searcher.search(&query, &filter_all_collector).unwrap();

@@ -62,8 +62,11 @@ pub fn test_filter_collector() -> crate::Result<()> {
         > 0
 }

-    let filter_dates_collector =
-        FilterCollector::new("date".to_string(), &date_filter, TopDocs::with_limit(5));
+    let filter_dates_collector = FilterCollector::new(
+        "date".to_string(),
+        &date_filter,
+        TopDocs::with_limit(5).order_by_score(),
+    );
     let filtered_date_docs = searcher.search(&query, &filter_dates_collector)?;

     assert_eq!(filtered_date_docs.len(), 2);
@@ -1,12 +1,7 @@
 use std::cmp::Ordering;
 use std::marker::PhantomData;

 use serde::{Deserialize, Serialize};

 use super::top_score_collector::TopNComputer;
 use crate::index::SegmentReader;
 use crate::{DocAddress, DocId, SegmentOrdinal};

 /// Contains a feature (field, score, etc.) of a document along with the document address.
 ///
 /// It guarantees stable sorting: in case of a tie on the feature, the document
@@ -19,7 +14,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
 pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
     /// The feature of the document. In practice, this
     /// is any type that implements `PartialOrd`.
-    pub feature: T,
+    pub sort_key: T,
     /// The document address. In practice, this is any
     /// type that implements `PartialOrd`, and is guaranteed
     /// to be unique for each document.
@@ -28,9 +23,9 @@ pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
 impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
     for ComparableDoc<T, D, R>
 {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
-            .field("feature", &self.feature)
+            .field("feature", &self.sort_key)
             .field("doc", &self.doc)
             .finish()
     }
@@ -46,8 +41,8 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R>
     #[inline]
     fn cmp(&self, other: &Self) -> Ordering {
         let by_feature = self
-            .feature
-            .partial_cmp(&other.feature)
+            .sort_key
+            .partial_cmp(&other.sort_key)
             .map(|ord| if R { ord.reverse() } else { ord })
             .unwrap_or(Ordering::Equal);
@@ -67,308 +62,3 @@ impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T,
 }

 impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

pub(crate) struct TopCollector<T> {
    pub limit: usize,
    pub offset: usize,
    _marker: PhantomData<T>,
}

impl<T> TopCollector<T>
where T: PartialOrd + Clone
{
    /// Creates a top collector, with a number of documents equal to "limit".
    ///
    /// # Panics
    /// The method panics if limit is 0.
    pub fn with_limit(limit: usize) -> TopCollector<T> {
        assert!(limit >= 1, "Limit must be strictly greater than 0.");
        Self {
            limit,
            offset: 0,
            _marker: PhantomData,
        }
    }

    /// Skips the first "offset" documents when collecting.
    ///
    /// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
    /// Lucene's TopDocsCollector.
    pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
        self.offset = offset;
        self
    }
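
For instance (taken from test_top_collector_with_limit_and_offset below), with_limit(2).and_offset(1) computes the top 3 hits overall and then skips the best one:

// hits 0.9, 0.8, 0.7, 0.6, 0.5  ->  returns [(0.8, ...), (0.7, ...)]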

    pub fn merge_fruits(
        &self,
        children: Vec<Vec<(T, DocAddress)>>,
    ) -> crate::Result<Vec<(T, DocAddress)>> {
        if self.limit == 0 {
            return Ok(Vec::new());
        }
        let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset);
        for child_fruit in children {
            for (feature, doc) in child_fruit {
                top_collector.push(feature, doc);
            }
        }

        Ok(top_collector
            .into_sorted_vec()
            .into_iter()
            .skip(self.offset)
            .map(|cdoc| (cdoc.feature, cdoc.doc))
            .collect())
    }

    pub(crate) fn for_segment<F: PartialOrd + Clone>(
        &self,
        segment_id: SegmentOrdinal,
        _: &SegmentReader,
    ) -> TopSegmentCollector<F> {
        TopSegmentCollector::new(segment_id, self.limit + self.offset)
    }

    /// Creates a new TopCollector with the same limit and offset.
    ///
    /// Ideally we would use Into, but the blanket implementation seems to cause the Scorer traits
    /// to fail.
    #[doc(hidden)]
    pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
        TopCollector {
            limit: self.limit,
            offset: self.offset,
            _marker: PhantomData,
        }
    }
}

/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on repeatedly truncating on the median after K * 2 documents.
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
    /// We reverse the order of the feature in order to
    /// have top semantics instead of bottom semantics.
    topn_computer: TopNComputer<T, DocId>,
    segment_ord: u32,
}

impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
    fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
        TopSegmentCollector {
            topn_computer: TopNComputer::new(limit),
            segment_ord,
        }
    }
}

impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
    pub fn harvest(self) -> Vec<(T, DocAddress)> {
        let segment_ord = self.segment_ord;
        self.topn_computer
            .into_sorted_vec()
            .into_iter()
            .map(|comparable_doc| {
                (
                    comparable_doc.feature,
                    DocAddress {
                        segment_ord,
                        doc_id: comparable_doc.doc,
                    },
                )
            })
            .collect()
    }

    /// Collects a document scored by the given feature.
    ///
    /// It collects documents until it has reached the max capacity. Once it reaches capacity, it
    /// will compare the lowest scoring item with the given one and keep whichever is greater.
    #[inline]
    pub fn collect(&mut self, doc: DocId, feature: T) {
        self.topn_computer.push(feature, doc);
    }
}

#[cfg(test)]
mod tests {
    use super::{TopCollector, TopSegmentCollector};
    use crate::DocAddress;

    #[test]
    fn test_top_collector_not_at_capacity() {
        let mut top_collector = TopSegmentCollector::new(0, 4);
        top_collector.collect(1, 0.8);
        top_collector.collect(3, 0.2);
        top_collector.collect(5, 0.3);
        assert_eq!(
            top_collector.harvest(),
            vec![
                (0.8, DocAddress::new(0, 1)),
                (0.3, DocAddress::new(0, 5)),
                (0.2, DocAddress::new(0, 3))
            ]
        );
    }

    #[test]
    fn test_top_collector_at_capacity() {
        let mut top_collector = TopSegmentCollector::new(0, 4);
        top_collector.collect(1, 0.8);
        top_collector.collect(3, 0.2);
        top_collector.collect(5, 0.3);
        top_collector.collect(7, 0.9);
        top_collector.collect(9, -0.2);
        assert_eq!(
            top_collector.harvest(),
            vec![
                (0.9, DocAddress::new(0, 7)),
                (0.8, DocAddress::new(0, 1)),
                (0.3, DocAddress::new(0, 5)),
                (0.2, DocAddress::new(0, 3))
            ]
        );
    }

    #[test]
    fn test_top_segment_collector_stable_ordering_for_equal_feature() {
        // given that the documents are collected in ascending doc id order,
        // when harvesting we have to guarantee stable sorting in case of a tie
        // on the score
        let doc_ids_collection = [4, 5, 6];
        let score = 3.3f32;

        let mut top_collector_limit_2 = TopSegmentCollector::new(0, 2);
        for id in &doc_ids_collection {
            top_collector_limit_2.collect(*id, score);
        }

        let mut top_collector_limit_3 = TopSegmentCollector::new(0, 3);
        for id in &doc_ids_collection {
            top_collector_limit_3.collect(*id, score);
        }

        assert_eq!(
            top_collector_limit_2.harvest(),
            top_collector_limit_3.harvest()[..2].to_vec(),
        );
    }

    #[test]
    fn test_top_collector_with_limit_and_offset() {
        let collector = TopCollector::with_limit(2).and_offset(1);

        let results = collector
            .merge_fruits(vec![vec![
                (0.9, DocAddress::new(0, 1)),
                (0.8, DocAddress::new(0, 2)),
                (0.7, DocAddress::new(0, 3)),
                (0.6, DocAddress::new(0, 4)),
                (0.5, DocAddress::new(0, 5)),
            ]])
            .unwrap();

        assert_eq!(
            results,
            vec![(0.8, DocAddress::new(0, 2)), (0.7, DocAddress::new(0, 3))]
        );
    }

    #[test]
    fn test_top_collector_with_limit_larger_than_set_and_offset() {
        let collector = TopCollector::with_limit(2).and_offset(1);

        let results = collector
            .merge_fruits(vec![vec![
                (0.9, DocAddress::new(0, 1)),
                (0.8, DocAddress::new(0, 2)),
            ]])
            .unwrap();

        assert_eq!(results, vec![(0.8, DocAddress::new(0, 2))]);
    }

    #[test]
    fn test_top_collector_with_limit_and_offset_larger_than_set() {
        let collector = TopCollector::with_limit(2).and_offset(20);

        let results = collector
            .merge_fruits(vec![vec![
                (0.9, DocAddress::new(0, 1)),
                (0.8, DocAddress::new(0, 2)),
            ]])
            .unwrap();

        assert_eq!(results, vec![]);
    }
}

#[cfg(all(test, feature = "unstable"))]
mod bench {
    use test::Bencher;

    use super::TopSegmentCollector;

    #[bench]
    fn bench_top_segment_collector_collect_not_at_capacity(b: &mut Bencher) {
        let mut top_collector = TopSegmentCollector::new(0, 400);

        b.iter(|| {
            for i in 0..100 {
                top_collector.collect(i, 0.8);
            }
        });
    }

    #[bench]
    fn bench_top_segment_collector_collect_at_capacity(b: &mut Bencher) {
        let mut top_collector = TopSegmentCollector::new(0, 100);

        for i in 0..100 {
            top_collector.collect(i, 0.8);
        }

        b.iter(|| {
            for i in 0..100 {
                top_collector.collect(i, 0.8);
            }
        });
    }

    #[bench]
    fn bench_top_segment_collector_collect_and_harvest_many_ties(b: &mut Bencher) {
        b.iter(|| {
            let mut top_collector = TopSegmentCollector::new(0, 100);

            for i in 0..100 {
                top_collector.collect(i, 0.8);
            }

            // it would be nice to be able to do the setup N times but still
            // measure only harvest(). We can't, since harvest() consumes
            // the top_collector.
            top_collector.harvest()
        });
    }

    #[bench]
    fn bench_top_segment_collector_collect_and_harvest_no_tie(b: &mut Bencher) {
        b.iter(|| {
            let mut top_collector = TopSegmentCollector::new(0, 100);
            let mut score = 1.0;

            for i in 0..100 {
                score += 1.0;
                top_collector.collect(i, score);
            }

            // it would be nice to be able to do the setup N times but still
            // measure only harvest(). We can't, since harvest() consumes
            // the top_collector.
            top_collector.harvest()
        });
    }
}
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,124 +0,0 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Result, Score, SegmentReader};

pub(crate) struct TweakedScoreTopCollector<TScoreTweaker, TScore = Score> {
    score_tweaker: TScoreTweaker,
    collector: TopCollector<TScore>,
}

impl<TScoreTweaker, TScore> TweakedScoreTopCollector<TScoreTweaker, TScore>
where TScore: Clone + PartialOrd
{
    pub fn new(
        score_tweaker: TScoreTweaker,
        collector: TopCollector<TScore>,
    ) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
        TweakedScoreTopCollector {
            score_tweaker,
            collector,
        }
    }
}

/// A `ScoreSegmentTweaker` makes it possible to modify the default score
/// for a given document belonging to a specific segment.
///
/// It is the segment local version of the [`ScoreTweaker`].
pub trait ScoreSegmentTweaker<TScore>: 'static {
    /// Tweak the given `score` for the document `doc`.
    fn score(&mut self, doc: DocId, score: Score) -> TScore;
}

/// `ScoreTweaker` makes it possible to tweak the score
/// emitted by the scorer into another one.
///
/// The `ScoreTweaker` itself does not make much of the computation itself.
/// Instead, it helps constructing `Self::Child` instances that will compute
/// the score at a segment scale.
pub trait ScoreTweaker<TScore>: Sync {
    /// Type of the associated [`ScoreSegmentTweaker`].
    type Child: ScoreSegmentTweaker<TScore>;

    /// Builds a child tweaker for a specific segment. The child scorer is associated with
    /// a specific segment.
    fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}

impl<TScoreTweaker, TScore> Collector for TweakedScoreTopCollector<TScoreTweaker, TScore>
where
    TScoreTweaker: ScoreTweaker<TScore> + Send + Sync,
    TScore: 'static + PartialOrd + Clone + Send + Sync,
{
    type Fruit = Vec<(TScore, DocAddress)>;

    type Child = TopTweakedScoreSegmentCollector<TScoreTweaker::Child, TScore>;

    fn for_segment(
        &self,
        segment_local_id: u32,
        segment_reader: &SegmentReader,
    ) -> Result<Self::Child> {
        let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?;
        let segment_collector = self.collector.for_segment(segment_local_id, segment_reader);
        Ok(TopTweakedScoreSegmentCollector {
            segment_collector,
            segment_scorer,
        })
    }

    fn requires_scoring(&self) -> bool {
        true
    }

    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
        self.collector.merge_fruits(segment_fruits)
    }
}

pub struct TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
    TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
    segment_collector: TopSegmentCollector<TScore>,
    segment_scorer: TSegmentScoreTweaker,
}

impl<TSegmentScoreTweaker, TScore> SegmentCollector
    for TopTweakedScoreSegmentCollector<TSegmentScoreTweaker, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync,
    TSegmentScoreTweaker: 'static + ScoreSegmentTweaker<TScore>,
{
    type Fruit = Vec<(TScore, DocAddress)>;

    fn collect(&mut self, doc: DocId, score: Score) {
        let score = self.segment_scorer.score(doc, score);
        self.segment_collector.collect(doc, score);
    }

    fn harvest(self) -> Vec<(TScore, DocAddress)> {
        self.segment_collector.harvest()
    }
}

impl<F, TScore, TSegmentScoreTweaker> ScoreTweaker<TScore> for F
where
    F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker,
    TSegmentScoreTweaker: ScoreSegmentTweaker<TScore>,
{
    type Child = TSegmentScoreTweaker;

    fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        Ok((self)(segment_reader))
    }
}

impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where F: 'static + FnMut(DocId, Score) -> TScore
{
    fn score(&mut self, doc: DocId, score: Score) -> TScore {
        (self)(doc, score)
    }
}
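Editor's note: the blanket impls at the end of this removed file are what let plain closures act as tweakers. As a minimal sketch of that pattern, assuming the public `TopDocs::tweak_score` entry point and a hypothetical u64 fast field named "popularity" (neither is part of this diff):

fn popularity_boosted_search(
    searcher: &tantivy::Searcher,
    query: &dyn tantivy::query::Query,
) -> tantivy::Result<Vec<(tantivy::Score, tantivy::DocAddress)>> {
    use tantivy::collector::TopDocs;
    use tantivy::{DocId, Score, SegmentReader};

    let collector = TopDocs::with_limit(10).tweak_score(|segment_reader: &SegmentReader| {
        // The outer closure plays the `ScoreTweaker` role: one call per segment.
        let popularity = segment_reader
            .fast_fields()
            .u64("popularity")
            .expect("popularity fast field");
        // The inner closure plays the `ScoreSegmentTweaker` role: one call per document.
        move |doc: DocId, original_score: Score| {
            let boost = popularity.first(doc).unwrap_or(0) as Score;
            original_score * (1.0 + boost)
        }
    });
    searcher.search(query, &collector)
}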
@@ -69,7 +69,7 @@ fn assert_date_time_precision(index: &Index, doc_store_precision: DateTimePrecisi
        .parse_query("dateformat")
        .expect("Failed to parse query");
    let top_docs = searcher
        .search(&query, &TopDocs::with_limit(1))
        .search(&query, &TopDocs::with_limit(1).order_by_score())
        .expect("Search failed");

    assert_eq!(top_docs.len(), 1, "Expected 1 search result");

@@ -48,7 +48,15 @@ impl Executor {
        F: Sized + Sync + Fn(A) -> crate::Result<R>,
    {
        match self {
            Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
            Executor::SingleThread => {
                // Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
                // harder.
                let mut result = Vec::with_capacity(args.size_hint().0);
                for arg in args {
                    result.push(f(arg)?);
                }
                Ok(result)
            }
            Executor::ThreadPool(pool) => {
                let args: Vec<A> = args.collect();
                let num_fruits = args.len();

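Editor's note: a minimal standalone sketch (not from the patch) of the trade-off the new `SingleThread` arm makes. `args.map(f).collect::<crate::Result<_>>()` funnels every item through `FromIterator` adapter frames, which clutters profiler stack traces; an explicit loop is behaviorally equivalent, including the early return on the first error:

fn map_all<A, R, E>(
    args: impl Iterator<Item = A>,
    f: impl Fn(A) -> Result<R, E>,
) -> Result<Vec<R>, E> {
    // Reserve using the iterator's lower bound, as the patch does with size_hint().0.
    let mut result = Vec::with_capacity(args.size_hint().0);
    for arg in args {
        result.push(f(arg)?); // propagate the first error, exactly like `collect` would
    }
    Ok(result)
}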
@@ -3,6 +3,7 @@ use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;

use crate::indexer::indexing_term::IndexingTerm;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
@@ -77,7 +78,7 @@ fn index_json_object<'a, V: Value<'a>>(
    doc: DocId,
    json_visitor: V::ObjectIter,
    text_analyzer: &mut TextAnalyzer,
    term_buffer: &mut Term,
    term_buffer: &mut IndexingTerm,
    json_path_writer: &mut JsonPathWriter,
    postings_writer: &mut dyn PostingsWriter,
    ctx: &mut IndexingContext,
@@ -107,17 +108,17 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
    doc: DocId,
    json_value: V,
    text_analyzer: &mut TextAnalyzer,
    term_buffer: &mut Term,
    term_buffer: &mut IndexingTerm,
    json_path_writer: &mut JsonPathWriter,
    postings_writer: &mut dyn PostingsWriter,
    ctx: &mut IndexingContext,
    positions_per_path: &mut IndexingPositionsPerPath,
) {
    let set_path_id = |term_buffer: &mut Term, unordered_id: u32| {
    let set_path_id = |term_buffer: &mut IndexingTerm, unordered_id: u32| {
        term_buffer.truncate_value_bytes(0);
        term_buffer.append_bytes(&unordered_id.to_be_bytes());
    };
    let set_type = |term_buffer: &mut Term, typ: Type| {
    let set_type = |term_buffer: &mut IndexingTerm, typ: Type| {
        term_buffer.append_bytes(&[typ.to_code()]);
    };

@@ -225,6 +225,7 @@ impl Searcher {
        enabled_scoring: EnableScoring,
    ) -> crate::Result<C::Fruit> {
        let weight = query.weight(enabled_scoring)?;
        collector.check_schema(self.schema())?;
        let segment_readers = self.segment_readers();
        let fruits = executor.map(
            |(segment_ord, segment_reader)| {

@@ -108,7 +108,7 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
    /// Opens a file and returns a boxed `FileHandle`.
    ///
    /// Users of `Directory` should typically call `Directory::open_read(...)`,
    /// while `Directory` implementor should implement `get_file_handle()`.
    /// while `Directory` implementer should implement `get_file_handle()`.
    fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError>;

    /// Once a virtual file is open, its data may not

@@ -104,7 +104,7 @@ pub enum TantivyError {
    #[error("{0:?}")]
    IncompatibleIndex(Incompatibility),
    /// An internal error occurred. This is are internal states that should not be reached.
    /// e.g. a datastructure is incorrectly inititalized.
    /// e.g. a datastructure is incorrectly initialized.
    #[error("Internal error: '{0}'")]
    InternalError(String),
    #[error("Deserialize error: {0}")]

@@ -726,22 +726,22 @@ mod tests {
            .column_opt::<DateTime>("multi_date")
            .unwrap()
            .unwrap();
        let mut dates = Vec::new();

        {
            assert_eq!(date_fast_field.get_val(0).into_timestamp_nanos(), 1i64);
            dates_fast_field.fill_vals(0u32, &mut dates);
            let dates: Vec<DateTime> = dates_fast_field.values_for_doc(0u32).collect();
            assert_eq!(dates.len(), 2);
            assert_eq!(dates[0].into_timestamp_nanos(), 2i64);
            assert_eq!(dates[1].into_timestamp_nanos(), 3i64);
        }
        {
            assert_eq!(date_fast_field.get_val(1).into_timestamp_nanos(), 4i64);
            dates_fast_field.fill_vals(1u32, &mut dates);
            let dates: Vec<DateTime> = dates_fast_field.values_for_doc(1u32).collect();
            assert!(dates.is_empty());
        }
        {
            assert_eq!(date_fast_field.get_val(2).into_timestamp_nanos(), 0i64);
            dates_fast_field.fill_vals(2u32, &mut dates);
            let dates: Vec<DateTime> = dates_fast_field.values_for_doc(2u32).collect();
            assert_eq!(dates.len(), 2);
            assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
            assert_eq!(dates[1].into_timestamp_nanos(), 6i64);

@@ -276,13 +276,14 @@ impl Default for IndexSettings {
}

/// The order to sort by
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Order {
    /// Ascending Order
    Asc,
    /// Descending Order
    Desc,
}

impl Order {
    /// return if the Order is ascending
    pub fn is_asc(&self) -> bool {

@@ -608,7 +608,7 @@ mod test {
            term_dictionary_size: Some(ByteCount::from(100u64)),
            postings_size: Some(ByteCount::from(1_000u64)),
            positions_size: Some(ByteCount::from(2_000u64)),
            fast_size: Some(ByteCount::from(1_000u64).into()),
            fast_size: Some(ByteCount::from(1_000u64)),
        };
        let field_metadata2 = FieldMetadata {
            field_name: "a".to_string(),
@@ -617,7 +617,7 @@ mod test {
            term_dictionary_size: Some(ByteCount::from(80u64)),
            postings_size: Some(ByteCount::from(1_500u64)),
            positions_size: Some(ByteCount::from(2_500u64)),
            fast_size: Some(ByteCount::from(3_000u64).into()),
            fast_size: Some(ByteCount::from(3_000u64)),
        };
        let expected = FieldMetadata {
            field_name: "a".to_string(),
@@ -626,7 +626,7 @@ mod test {
            term_dictionary_size: Some(ByteCount::from(180u64)),
            postings_size: Some(ByteCount::from(2_500u64)),
            positions_size: Some(ByteCount::from(4_500u64)),
            fast_size: Some(ByteCount::from(4_000u64).into()),
            fast_size: Some(ByteCount::from(4_000u64)),
        };
        assert_merge(
            &[vec![field_metadata1.clone()], vec![field_metadata2]],

@@ -513,7 +513,7 @@ impl<D: Document> IndexWriter<D> {
    /// let searcher = index.reader()?.searcher();
    /// let query_parser = QueryParser::for_index(&index, vec![title]);
    /// let query_promo = query_parser.parse_query("Prometheus")?;
    /// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1))?;
    /// let top_docs_promo = searcher.search(&query_promo, &TopDocs::with_limit(1).order_by_score())?;
    ///
    /// assert!(top_docs_promo.is_empty());
    /// Ok(())
@@ -946,11 +946,11 @@ mod tests {
        let searcher = reader.searcher();

        let a_docs = searcher
            .search(&a_query, &TopDocs::with_limit(1))
            .search(&a_query, &TopDocs::with_limit(1).order_by_score())
            .expect("search for a failed");

        let b_docs = searcher
            .search(&b_query, &TopDocs::with_limit(1))
            .search(&b_query, &TopDocs::with_limit(1).order_by_score())
            .expect("search for b failed");

        assert_eq!(a_docs.len(), 1);
@@ -2014,8 +2014,9 @@ mod tests {
            let query = QueryParser::for_index(&index, vec![field])
                .parse_query(term)
                .unwrap();
            let top_docs: Vec<(f32, DocAddress)> =
                searcher.search(&query, &TopDocs::with_limit(1000)).unwrap();
            let top_docs: Vec<(f32, DocAddress)> = searcher
                .search(&query, &TopDocs::with_limit(1000).order_by_score())
                .unwrap();

            top_docs.iter().map(|el| el.1).collect::<Vec<_>>()
        };
@@ -2449,8 +2450,9 @@ mod tests {
            Term::from_field_u64(id_field, existing_id),
            IndexRecordOption::Basic,
        );
        let top_docs: Vec<(f32, DocAddress)> =
            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
        let top_docs: Vec<(f32, DocAddress)> = searcher
            .search(&query, &TopDocs::with_limit(10).order_by_score())
            .unwrap();

        assert_eq!(top_docs.len(), 1); // Was failing

@@ -2491,8 +2493,9 @@ mod tests {
            Term::from_field_i64(id_field, 10i64),
            IndexRecordOption::Basic,
        );
        let top_docs: Vec<(f32, DocAddress)> =
            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
        let top_docs: Vec<(f32, DocAddress)> = searcher
            .search(&query, &TopDocs::with_limit(10).order_by_score())
            .unwrap();

        assert_eq!(top_docs.len(), 1); // Fails

@@ -2500,8 +2503,9 @@ mod tests {
            Term::from_field_i64(id_field, 30i64),
            IndexRecordOption::Basic,
        );
        let top_docs: Vec<(f32, DocAddress)> =
            searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
        let top_docs: Vec<(f32, DocAddress)> = searcher
            .search(&query, &TopDocs::with_limit(10).order_by_score())
            .unwrap();

        assert_eq!(top_docs.len(), 1); // Fails

184 src/indexer/indexing_term.rs (new file)
@@ -0,0 +1,184 @@
use std::net::Ipv6Addr;

use columnar::MonotonicallyMappableToU128;

use crate::fastfield::FastValue;
use crate::schema::{Field, Type};

/// Term represents the value that the token can take.
/// It's a serialized representation over different types.
///
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
/// 4 bytes are the field id, and the last byte is the type.
///
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
#[derive(Clone)]
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
where B: AsRef<[u8]>;

/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 5;

impl IndexingTerm {
    /// Create a new Term with a buffer with a given capacity.
    pub fn with_capacity(capacity: usize) -> IndexingTerm {
        let mut data = Vec::with_capacity(TERM_METADATA_LENGTH + capacity);
        data.resize(TERM_METADATA_LENGTH, 0u8);
        IndexingTerm(data)
    }

    /// Panics when the term is not empty... ie: some value is set.
    /// Use `clear_with_field_and_type` in that case.
    ///
    /// Sets field and the type.
    pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
        assert!(self.is_empty());
        self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
        self.0[4] = typ.to_code();
    }

    /// Is empty if there are no value bytes.
    pub fn is_empty(&self) -> bool {
        self.0.len() == TERM_METADATA_LENGTH
    }

    /// Removes the value_bytes and set the field and type code.
    pub(crate) fn clear_with_field_and_type(&mut self, typ: Type, field: Field) {
        self.truncate_value_bytes(0);
        self.set_field_and_type(field, typ);
    }

    /// Sets a u64 value in the term.
    ///
    /// U64 are serialized using (8-byte) BigEndian
    /// representation.
    /// The use of BigEndian has the benefit of preserving
    /// the natural order of the values.
    pub fn set_u64(&mut self, val: u64) {
        self.set_fast_value(val);
    }

    /// Sets a `i64` value in the term.
    pub fn set_i64(&mut self, val: i64) {
        self.set_fast_value(val);
    }

    /// Sets a `f64` value in the term.
    pub fn set_f64(&mut self, val: f64) {
        self.set_fast_value(val);
    }

    /// Sets a `bool` value in the term.
    pub fn set_bool(&mut self, val: bool) {
        self.set_fast_value(val);
    }

    fn set_fast_value<T: FastValue>(&mut self, val: T) {
        self.set_bytes(val.to_u64().to_be_bytes().as_ref());
    }

    /// Append a type marker + fast value to a term.
    /// This is used in JSON type to append a fast value after the path.
    ///
    /// It will not clear existing bytes.
    pub fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) {
        self.0.push(T::to_type().to_code());
        let value = val.to_u64();
        self.0.extend(value.to_be_bytes().as_ref());
    }

    /// Sets a `Ipv6Addr` value in the term.
    pub fn set_ip_addr(&mut self, val: Ipv6Addr) {
        self.set_bytes(val.to_u128().to_be_bytes().as_ref());
    }

    /// Sets the value of a `Bytes` field.
    pub fn set_bytes(&mut self, bytes: &[u8]) {
        self.truncate_value_bytes(0);
        self.0.extend(bytes);
    }

    /// Truncates the value bytes of the term. Value and field type stays the same.
    pub fn truncate_value_bytes(&mut self, len: usize) {
        self.0.truncate(len + TERM_METADATA_LENGTH);
    }

    /// The length of the bytes.
    pub fn len_bytes(&self) -> usize {
        self.0.len() - TERM_METADATA_LENGTH
    }

    /// Appends value bytes to the Term.
    ///
    /// This function returns the segment that has just been added.
    #[inline]
    pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
        let len_before = self.0.len();
        self.0.extend_from_slice(bytes);
        &mut self.0[len_before..]
    }
}

impl<B> IndexingTerm<B>
where B: AsRef<[u8]>
{
    /// Returns the serialized representation of Term.
    /// This includes field_id, value type and value.
    ///
    /// Do NOT rely on this byte representation in the index.
    /// This value is likely to change in the future.
    #[inline]
    pub fn serialized_term(&self) -> &[u8] {
        self.0.as_ref()
    }
}

#[cfg(test)]
mod tests {

    use crate::schema::*;

    #[test]
    pub fn test_term_str() {
        let mut schema_builder = Schema::builder();
        schema_builder.add_text_field("text", STRING);
        let title_field = schema_builder.add_text_field("title", STRING);
        let term = Term::from_field_text(title_field, "test");
        assert_eq!(term.field(), title_field);
        assert_eq!(term.typ(), Type::Str);
        assert_eq!(term.value().as_str(), Some("test"))
    }

    /// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
    /// <field> + <type byte> + <value len>
    ///
    /// - <field> is a big endian encoded u32 field id
    /// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
    ///   remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
    ///   type is the type of the leaf of the json.
    /// - <value> is, if this is not the json term, a binary representation specific to the type.
    ///   If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
    const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;

    #[test]
    pub fn test_term_u64() {
        let mut schema_builder = Schema::builder();
        let count_field = schema_builder.add_u64_field("count", INDEXED);
        let term = Term::from_field_u64(count_field, 983u64);
        assert_eq!(term.field(), count_field);
        assert_eq!(term.typ(), Type::U64);
        assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
        assert_eq!(term.value().as_u64(), Some(983u64))
    }

    #[test]
    pub fn test_term_bool() {
        let mut schema_builder = Schema::builder();
        let bool_field = schema_builder.add_bool_field("bool", INDEXED);
        let term = Term::from_field_bool(bool_field, true);
        assert_eq!(term.field(), bool_field);
        assert_eq!(term.typ(), Type::Bool);
        assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
        assert_eq!(term.value().as_bool(), Some(true))
    }
}
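Editor's note: a crate-internal sanity sketch (illustrative, not part of the diff) of the byte layout the doc comment above describes: 4 big-endian field-id bytes plus 1 type byte of metadata, followed by the value bytes (8 for a fast value such as a u64):

#[test]
fn indexing_term_layout_sketch() {
    use crate::indexer::indexing_term::IndexingTerm;
    use crate::schema::{Field, Type};

    let mut term = IndexingTerm::with_capacity(16);
    term.set_field_and_type(Field::from_field_id(2), Type::U64);
    term.set_u64(983u64);
    // TERM_METADATA_LENGTH (5) + 8 big-endian value bytes.
    assert_eq!(term.serialized_term().len(), 5 + 8);
    // len_bytes() reports the value bytes only.
    assert_eq!(term.len_bytes(), 8);
}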
@@ -104,8 +104,9 @@ mod tests {
            let query = QueryParser::for_index(&index, vec![my_text_field])
                .parse_query(term)
                .unwrap();
            let top_docs: Vec<(f32, DocAddress)> =
                searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
            let top_docs: Vec<(f32, DocAddress)> = searcher
                .search(&query, &TopDocs::with_limit(3).order_by_score())
                .unwrap();

            top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
        };

@@ -1518,7 +1518,8 @@ mod tests {
        let searcher = reader.searcher();
        let mut term_scorer = term_query
            .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
            .specialized_scorer(searcher.segment_reader(0u32), 1.0)?;
            .term_scorer_for_test(searcher.segment_reader(0u32), 1.0)?
            .unwrap();
        assert_eq!(term_scorer.doc(), 0);
        assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
        assert_nearly_equals!(term_scorer.score(), 0.0079681855);
@@ -1533,7 +1534,8 @@ mod tests {
        for segment_reader in searcher.segment_readers() {
            let mut term_scorer = term_query
                .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
                .specialized_scorer(segment_reader, 1.0)?;
                .term_scorer_for_test(segment_reader, 1.0)?
                .unwrap();
            // the difference compared to before is intrinsic to the bm25 formula. no worries
            // there.
            for doc in segment_reader.doc_ids_alive() {
@@ -1558,7 +1560,8 @@ mod tests {
        let segment_reader = searcher.segment_reader(0u32);
        let mut term_scorer = term_query
            .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
            .specialized_scorer(segment_reader, 1.0)?;
            .term_scorer_for_test(segment_reader, 1.0)?
            .unwrap();
        // the difference compared to before is intrinsic to the bm25 formula. no worries there.
        for doc in segment_reader.doc_ids_alive() {
            assert_eq!(term_scorer.doc(), doc);

@@ -12,6 +12,7 @@ mod doc_opstamp_mapping;
mod flat_map_with_buffer;
pub(crate) mod index_writer;
pub(crate) mod index_writer_status;
pub(crate) mod indexing_term;
mod log_merge_policy;
mod merge_index_test;
mod merge_operation;
@@ -181,6 +182,7 @@ mod tests_mmap {
        let field_name_out = ".";
        test_json_field_name(field_name_in, field_name_out);
    }

    #[test]
    fn test_json_field_dot() {
        // Test when field name contains a '.'
@@ -587,7 +589,9 @@ mod tests_mmap {
            };
            let query_str = &format!("{}:{}", indexed_field.field_name, val);
            let query = query_parser.parse_query(query_str).unwrap();
            let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
            let count_docs = searcher
                .search(&*query, &TopDocs::with_limit(2).order_by_score())
                .unwrap();
            if indexed_field.field_name.contains("empty") || indexed_field.typ == Type::Json {
                assert_eq!(count_docs.len(), 0);
            } else {
@@ -659,7 +663,9 @@ mod tests_mmap {
        for (indexed_field, val) in fields_and_vals.iter() {
            let query_str = &format!("{indexed_field}:{val}");
            let query = query_parser.parse_query(query_str).unwrap();
            let count_docs = searcher.search(&*query, &TopDocs::with_limit(2)).unwrap();
            let count_docs = searcher
                .search(&*query, &TopDocs::with_limit(2).order_by_score())
                .unwrap();
            assert!(!count_docs.is_empty(), "{indexed_field}:{val}");
        }
        // Test if field name can be used for aggregation

@@ -1052,8 +1052,9 @@ mod tests {
            let query = QueryParser::for_index(&index, vec![text_field])
                .parse_query(term)
                .unwrap();
            let top_docs: Vec<(f32, DocAddress)> =
                searcher.search(&query, &TopDocs::with_limit(3)).unwrap();
            let top_docs: Vec<(f32, DocAddress)> = searcher
                .search(&query, &TopDocs::with_limit(3).order_by_score())
                .unwrap();

            top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
        };

@@ -7,6 +7,7 @@ use super::operation::AddOperation;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent};
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::segment_serializer::SegmentSerializer;
use crate::json_utils::{index_json_value, IndexingPositionsPerPath};
use crate::postings::{
@@ -14,7 +15,7 @@ use crate::postings::{
    PerFieldPostingsWriter, PostingsWriter,
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, Term, DATE_TIME_PRECISION_INDEXED};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};

@@ -55,7 +56,7 @@ pub struct SegmentWriter {
    pub(crate) json_positions_per_path: IndexingPositionsPerPath,
    pub(crate) doc_opstamps: Vec<Opstamp>,
    per_field_text_analyzers: Vec<TextAnalyzer>,
    term_buffer: Term,
    term_buffer: IndexingTerm,
    schema: Schema,
}

@@ -112,7 +113,7 @@ impl SegmentWriter {
            )?,
            doc_opstamps: Vec::with_capacity(1_000),
            per_field_text_analyzers,
            term_buffer: Term::with_capacity(16),
            term_buffer: IndexingTerm::with_capacity(16),
            schema,
        })
    }
@@ -519,7 +520,7 @@ mod tests {
            .reader()
            .unwrap()
            .searcher()
            .search(&text_query, &TopDocs::with_limit(4))
            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
            .unwrap();
        assert_eq!(score_docs.len(), 1);

@@ -528,7 +529,7 @@ mod tests {
            .reader()
            .unwrap()
            .searcher()
            .search(&text_query, &TopDocs::with_limit(4))
            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
            .unwrap();
        assert_eq!(score_docs.len(), 2);
    }
@@ -561,7 +562,7 @@ mod tests {
            .reader()
            .unwrap()
            .searcher()
            .search(&text_query, &TopDocs::with_limit(4))
            .search(&text_query, &TopDocs::with_limit(4).order_by_score())
            .unwrap();
        assert_eq!(score_docs.len(), 1);
    };

@@ -42,7 +42,6 @@ mod test {

    use super::Stamper;

    #[expect(clippy::redundant_clone)]
    #[test]
    fn test_stamper() {
        let stamper = Stamper::new(7u64);
@@ -58,7 +57,6 @@ mod test {
        assert_eq!(stamper.stamp(), 15u64);
    }

    #[expect(clippy::redundant_clone)]
    #[test]
    fn test_stamper_revert() {
        let stamper = Stamper::new(7u64);

|
||||
//! // Perform search.
|
||||
//! // `topdocs` contains the 10 most relevant doc ids, sorted by decreasing scores...
|
||||
//! let top_docs: Vec<(Score, DocAddress)> =
|
||||
//! searcher.search(&query, &TopDocs::with_limit(10))?;
|
||||
//! searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
|
||||
//!
|
||||
//! for (_score, doc_address) in top_docs {
|
||||
//! // Retrieve the actual content of documents given its `doc_address`.
|
||||
@@ -125,7 +125,7 @@
|
||||
//!
|
||||
//! - **Searching**: [Searcher] searches the segments with anything that implements
|
||||
//! [Query](query::Query) and merges the results. The list of [supported
|
||||
//! queries](query::Query#implementors). Custom Queries are supported by implementing the
|
||||
//! queries](query::Query#implementers). Custom Queries are supported by implementing the
|
||||
//! [Query](query::Query) trait.
|
||||
//!
|
||||
//! - **[Directory](directory)**: Abstraction over the storage where the index data is stored.
|
||||
|
||||
@@ -3,13 +3,14 @@ use std::io;
use common::json_path_writer::JSON_END_OF_PATH;
use stacker::Addr;

use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::{Field, Type};
use crate::schema::{Field, Type, ValueBytes};
use crate::tokenizer::TokenStream;
use crate::{DocId, Term};
use crate::DocId;

/// The `JsonPostingsWriter` is odd in that it relies on a hidden contract:
///
@@ -33,7 +34,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
        &mut self,
        doc: crate::DocId,
        pos: u32,
        term: &crate::Term,
        term: &IndexingTerm,
        ctx: &mut IndexingContext,
    ) {
        self.non_str_posting_writer.subscribe(doc, pos, term, ctx);
@@ -43,7 +44,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
        &mut self,
        doc_id: DocId,
        token_stream: &mut dyn TokenStream,
        term_buffer: &mut Term,
        term_buffer: &mut IndexingTerm,
        ctx: &mut IndexingContext,
        indexing_position: &mut IndexingPosition,
    ) {
@@ -64,40 +65,38 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
        ctx: &IndexingContext,
        serializer: &mut FieldSerializer,
    ) -> io::Result<()> {
        let mut term_buffer = Term::with_capacity(48);
        let mut term_buffer = JsonTermSerializer(Vec::with_capacity(48));
        let mut buffer_lender = BufferLender::default();
        term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
        let mut prev_term_id = u32::MAX;
        let mut term_path_len = 0; // this will be set in the first iteration
        for (_field, path_id, term, addr) in ordered_term_addrs {
            if prev_term_id != path_id.path_id() {
                term_buffer.truncate_value_bytes(0);
                term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());
                term_buffer.append_bytes(&[JSON_END_OF_PATH]);
                term_path_len = term_buffer.len_bytes();
                term_buffer.clear();
                term_buffer.append_json_path(ordered_id_to_path[path_id.path_id() as usize]);
                term_path_len = term_buffer.len();
                prev_term_id = path_id.path_id();
            }
            term_buffer.truncate_value_bytes(term_path_len);
            term_buffer.truncate(term_path_len);
            term_buffer.append_bytes(term);
            if let Some(json_value) = term_buffer.value().as_json_value_bytes() {
                let typ = json_value.typ();
                if typ == Type::Str {
                    SpecializedPostingsWriter::<Rec>::serialize_one_term(
                        term_buffer.serialized_value_bytes(),
                        *addr,
                        &mut buffer_lender,
                        ctx,
                        serializer,
                    )?;
                } else {
                    SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
                        term_buffer.serialized_value_bytes(),
                        *addr,
                        &mut buffer_lender,
                        ctx,
                        serializer,
                    )?;
                }

            let json_value = ValueBytes::wrap(term);
            let typ = json_value.typ();
            if typ == Type::Str {
                SpecializedPostingsWriter::<Rec>::serialize_one_term(
                    term_buffer.as_bytes(),
                    *addr,
                    &mut buffer_lender,
                    ctx,
                    serializer,
                )?;
            } else {
                SpecializedPostingsWriter::<DocIdRecorder>::serialize_one_term(
                    term_buffer.as_bytes(),
                    *addr,
                    &mut buffer_lender,
                    ctx,
                    serializer,
                )?;
            }
        }
        Ok(())
@@ -107,3 +106,48 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
        self.str_posting_writer.total_num_tokens() + self.non_str_posting_writer.total_num_tokens()
    }
}

struct JsonTermSerializer(Vec<u8>);
impl JsonTermSerializer {
    /// Appends a JSON path to the Term.
    /// The path is terminated by a special end-of-path 0 byte.
    #[inline]
    pub fn append_json_path(&mut self, path: &str) {
        let bytes = path.as_bytes();
        // Replace any occurrence of the end-of-path byte with Ascii '0' byte.
        if bytes.contains(&JSON_END_OF_PATH) {
            self.0.extend(
                bytes
                    .iter()
                    .map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
            );
        } else {
            self.0.extend_from_slice(bytes);
        }
        self.0.push(JSON_END_OF_PATH);
    }

    /// Appends value bytes to the Term.
    ///
    /// This function returns the segment that has just been added.
    #[inline]
    pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
        let len_before = self.0.len();
        self.0.extend_from_slice(bytes);
        &mut self.0[len_before..]
    }

    fn clear(&mut self) {
        self.0.clear();
    }
    fn truncate(&mut self, len: usize) {
        self.0.truncate(len);
    }
    fn len(&self) -> usize {
        self.0.len()
    }

    fn as_bytes(&self) -> &[u8] {
        &self.0
    }
}
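Editor's note: an illustrative crate-internal sketch (not part of the diff) of the escaping rule in `append_json_path` above. A `JSON_END_OF_PATH` byte occurring inside the path is rewritten to an ASCII '0', so the single terminator byte appended at the end stays unambiguous:

fn json_path_escaping_sketch() {
    use common::json_path_writer::JSON_END_OF_PATH;

    let mut term_buffer = JsonTermSerializer(Vec::new());
    term_buffer.append_json_path("attributes.color");
    let bytes = term_buffer.as_bytes();
    // Exactly one terminator, and it is the last byte.
    assert_eq!(bytes.last(), Some(&JSON_END_OF_PATH));
    assert!(!bytes[..bytes.len() - 1].contains(&JSON_END_OF_PATH));
}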
@@ -5,6 +5,7 @@ use std::ops::Range;
use stacker::Addr;

use crate::fieldnorm::FieldNormReaders;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
@@ -111,7 +112,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
    /// * term - the term
    /// * ctx - Contains a term hashmap and a memory arena to store all necessary posting list
    ///   information.
    fn subscribe(&mut self, doc: DocId, pos: u32, term: &Term, ctx: &mut IndexingContext);
    fn subscribe(&mut self, doc: DocId, pos: u32, term: &IndexingTerm, ctx: &mut IndexingContext);

    /// Serializes the postings on disk.
    /// The actual serialization format is handled by the `PostingsSerializer`.
@@ -128,7 +129,7 @@ pub(crate) trait PostingsWriter: Send + Sync {
        &mut self,
        doc_id: DocId,
        token_stream: &mut dyn TokenStream,
        term_buffer: &mut Term,
        term_buffer: &mut IndexingTerm,
        ctx: &mut IndexingContext,
        indexing_position: &mut IndexingPosition,
    ) {
@@ -198,7 +199,13 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {

impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
    #[inline]
    fn subscribe(&mut self, doc: DocId, position: u32, term: &Term, ctx: &mut IndexingContext) {
    fn subscribe(
        &mut self,
        doc: DocId,
        position: u32,
        term: &IndexingTerm,
        ctx: &mut IndexingContext,
    ) {
        debug_assert!(term.serialized_term().len() >= 4);
        self.total_num_tokens += 1;
        let (term_index, arena) = (&mut ctx.term_index, &mut ctx.arena);

@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
@@ -57,13 +59,13 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
    K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
}

fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
    let mut cache: [Score; 256] = [0.0; 256];
    for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
        let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
        *cache_mut = cached_tf_component(fieldnorm, average_fieldnorm);
    }
    cache
    Arc::new(cache)
}

/// A struct used for computing BM25 scores.
@@ -71,17 +73,20 @@ fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
pub struct Bm25Weight {
    idf_explain: Option<Explanation>,
    weight: Score,
    cache: [Score; 256],
    cache: Arc<[Score; 256]>,
    average_fieldnorm: Score,
}

impl Bm25Weight {
    /// Increase the weight by a multiplicative factor.
    pub fn boost_by(&self, boost: Score) -> Bm25Weight {
        if boost == 1.0f32 {
            return self.clone();
        }
        Bm25Weight {
            idf_explain: self.idf_explain.clone(),
            weight: self.weight * boost,
            cache: self.cache,
            cache: self.cache.clone(),
            average_fieldnorm: self.average_fieldnorm,
        }
    }

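Editor's note: the value cached per fieldnorm id above is the BM25 length-normalization component; sharing it behind an `Arc` makes `boost_by` a pointer clone instead of a 1 KiB array copy. With the crate's `K1` and `B` constants, and hedging on the exact shape of the final scoring formula elsewhere in the file, the cached quantity and its standard BM25 use are:

\[
\mathrm{cache}[i] = k_1\left(1 - b + b\,\frac{\mathrm{fieldnorm}(i)}{\mathrm{avg\_fieldnorm}}\right),
\qquad
\mathrm{score}(\mathrm{tf}, i) \approx \mathrm{weight}\cdot\frac{\mathrm{tf}}{\mathrm{tf} + \mathrm{cache}[i]}
\]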
@@ -9,7 +9,7 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
use crate::query::{
    intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
    intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
    RequiredOptionalScorer, Scorer, Weight,
};
use crate::{DocId, Score};
@@ -97,6 +97,15 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
    }
}

enum ShouldScorersCombinationMethod {
    // Should scorers are irrelevant.
    Ignored,
    // Only contributes to final score.
    Optional(SpecializedScorer),
    // Regardless of score, the should scorers may impact whether a document is matching or not.
    Required(SpecializedScorer),
}

/// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
    weights: Vec<(Occur, Box<dyn Weight>)>,
@@ -159,27 +168,50 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
    ) -> crate::Result<SpecializedScorer> {
        let num_docs = reader.num_docs();
        let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
        // Indicate how should clauses are combined with other clauses.
        enum CombinationMethod {
            Ignored,
            // Only contributes to final score.
            Optional(SpecializedScorer),
            Required(SpecializedScorer),

        // Indicate how should clauses are combined with must clauses.
        let mut must_scorers: Vec<Box<dyn Scorer>> =
            per_occur_scorers.remove(&Occur::Must).unwrap_or_default();
        let must_special_scorer_counts = remove_and_count_all_and_empty_scorers(&mut must_scorers);

        if must_special_scorer_counts.num_empty_scorers > 0 {
            return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
        }
        let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
        let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
        {

        let mut should_scorers = per_occur_scorers.remove(&Occur::Should).unwrap_or_default();
        let should_special_scorer_counts =
            remove_and_count_all_and_empty_scorers(&mut should_scorers);

        let mut exclude_scorers: Vec<Box<dyn Scorer>> = per_occur_scorers
            .remove(&Occur::MustNot)
            .unwrap_or_default();
        let exclude_special_scorer_counts =
            remove_and_count_all_and_empty_scorers(&mut exclude_scorers);

        if exclude_special_scorer_counts.num_all_scorers > 0 {
            // We exclude all documents at one point.
            return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
        }

        let minimum_number_should_match = self
            .minimum_number_should_match
            .saturating_sub(should_special_scorer_counts.num_all_scorers);

        let should_scorers: ShouldScorersCombinationMethod = {
            let num_of_should_scorers = should_scorers.len();
            if self.minimum_number_should_match > num_of_should_scorers {
            if minimum_number_should_match > num_of_should_scorers {
                // We don't have enough scorers to satisfy the minimum number of should matches.
                // The request will match no documents.
                return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
            }
            match self.minimum_number_should_match {
                0 => CombinationMethod::Optional(scorer_union(
            match minimum_number_should_match {
                0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
                0 => ShouldScorersCombinationMethod::Optional(scorer_union(
                    should_scorers,
                    &score_combiner_fn,
                    num_docs,
                )),
                1 => CombinationMethod::Required(scorer_union(
                1 => ShouldScorersCombinationMethod::Required(scorer_union(
                    should_scorers,
                    &score_combiner_fn,
                    num_docs,
@@ -187,76 +219,120 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
                n if num_of_should_scorers == n => {
                    // When num_of_should_scorers equals the number of should clauses,
                    // they are no different from must clauses.
                    must_scorers = match must_scorers.take() {
                        Some(mut must_scorers) => {
                            must_scorers.append(&mut should_scorers);
                            Some(must_scorers)
                        }
                        None => Some(should_scorers),
                    };
                    CombinationMethod::Ignored
                    must_scorers.append(&mut should_scorers);
                    ShouldScorersCombinationMethod::Ignored
                }
                _ => CombinationMethod::Required(SpecializedScorer::Other(scorer_disjunction(
                    should_scorers,
                    score_combiner_fn(),
                    self.minimum_number_should_match,
                ))),
            }
        } else {
            // None of should clauses are provided.
            if self.minimum_number_should_match > 0 {
                return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
            } else {
                CombinationMethod::Ignored
                _ => ShouldScorersCombinationMethod::Required(SpecializedScorer::Other(
                    scorer_disjunction(
                        should_scorers,
                        score_combiner_fn(),
                        self.minimum_number_should_match,
                    ),
                )),
            }
        };
        let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
            .remove(&Occur::MustNot)
            .map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
            .map(|specialized_scorer: SpecializedScorer| {
                into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
            });
        let positive_scorer = match (should_opt, must_scorers) {
            (CombinationMethod::Ignored, Some(must_scorers)) => {
                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))

        let exclude_scorer_opt: Option<Box<dyn Scorer>> = if exclude_scorers.is_empty() {
            None
        } else {
            let exclude_specialized_scorer: SpecializedScorer =
                scorer_union(exclude_scorers, DoNothingCombiner::default, num_docs);
            Some(into_box_scorer(
                exclude_specialized_scorer,
                DoNothingCombiner::default,
                num_docs,
            ))
        };

        let include_scorer = match (should_scorers, must_scorers) {
            (ShouldScorersCombinationMethod::Ignored, must_scorers) => {
                let boxed_scorer: Box<dyn Scorer> = if must_scorers.is_empty() {
                    // We do not have any should scorers, nor all scorers.
                    // There are still two cases here.
                    //
                    // If this follows the removal of some AllScorers in the should/must clauses,
                    // then we match all documents.
                    //
                    // Otherwise, it is really just an EmptyScorer.
                    if must_special_scorer_counts.num_all_scorers
                        + should_special_scorer_counts.num_all_scorers
                        > 0
                    {
                        Box::new(AllScorer::new(reader.max_doc()))
                    } else {
                        Box::new(EmptyScorer)
                    }
                } else {
                    intersect_scorers(must_scorers, num_docs)
                };
                SpecializedScorer::Other(boxed_scorer)
            }
            (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
                let must_scorer = intersect_scorers(must_scorers, num_docs);
                if self.scoring_enabled {
                    SpecializedScorer::Other(Box::new(
                        RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
            (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
                if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 {
                    // Optional options are promoted to required if no must scorers exists.
                    should_scorer
                } else {
                    let must_scorer = intersect_scorers(must_scorers, num_docs);
                    if self.scoring_enabled {
                        SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
                            _,
                            _,
                            TScoreCombiner,
                        >::new(
                            must_scorer,
                            into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
                        ),
                    ))
                } else {
                    SpecializedScorer::Other(must_scorer)
                        )))
                    } else {
                        SpecializedScorer::Other(must_scorer)
                    }
                }
            }
            (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
                must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
            (ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => {
                if must_scorers.is_empty() {
                    should_scorer
                } else {
                    must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
                    SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
                }
            }
            (CombinationMethod::Ignored, None) => {
                return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
            }
            (CombinationMethod::Required(should_scorer), None) => should_scorer,
            // Optional options are promoted to required if no must scorers exists.
            (CombinationMethod::Optional(should_scorer), None) => should_scorer,
        };
        if let Some(exclude_scorer) = exclude_scorer_opt {
            let positive_scorer_boxed =
                into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
            let include_scorer_boxed =
                into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
            Ok(SpecializedScorer::Other(Box::new(Exclude::new(
                positive_scorer_boxed,
                include_scorer_boxed,
                exclude_scorer,
            ))))
        } else {
            Ok(positive_scorer)
            Ok(include_scorer)
        }
    }
}

#[derive(Default, Copy, Clone, Debug)]
struct AllAndEmptyScorerCounts {
    num_all_scorers: usize,
    num_empty_scorers: usize,
}

fn remove_and_count_all_and_empty_scorers(
    scorers: &mut Vec<Box<dyn Scorer>>,
) -> AllAndEmptyScorerCounts {
    let mut counts = AllAndEmptyScorerCounts::default();
    scorers.retain(|scorer| {
        if scorer.is::<AllScorer>() {
            counts.num_all_scorers += 1;
            false
        } else if scorer.is::<EmptyScorer>() {
            counts.num_empty_scorers += 1;
            false
        } else {
            true
        }
    });
    counts
}

impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
    fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
        let num_docs = reader.num_docs();
@@ -293,7 +369,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin

        let mut explanation = Explanation::new("BooleanClause. sum of ...", scorer.score());
        for (occur, subweight) in &self.weights {
            if is_positive_occur(*occur) {
            if is_include_occur(*occur) {
                if let Ok(child_explanation) = subweight.explain(reader, doc) {
                    explanation.add_detail(child_explanation);
                }
@@ -377,7 +453,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
    }
}

fn is_positive_occur(occur: Occur) -> bool {
fn is_include_occur(occur: Occur) -> bool {
    match occur {
        Occur::Must | Occur::Should => true,
        Occur::MustNot => false,

@@ -14,8 +14,8 @@ mod tests {
    use crate::collector::TopDocs;
    use crate::query::term_query::TermScorer;
    use crate::query::{
        EnableScoring, Intersection, Occur, Query, QueryParser, RequiredOptionalScorer, Scorer,
        SumCombiner, TermQuery,
        AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser,
        RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
    };
    use crate::schema::*;
    use crate::{assert_nearly_equals, DocAddress, DocId, Index, IndexWriter, Score};
@@ -182,7 +182,7 @@ mod tests {
        let matching_topdocs = |query: &dyn Query| {
            reader
                .searcher()
                .search(query, &TopDocs::with_limit(3))
                .search(query, &TopDocs::with_limit(3).order_by_score())
                .unwrap()
        };

@@ -311,4 +311,67 @@ mod tests {
        assert_nearly_equals!(explanation.value(), std::f32::consts::LN_2);
        Ok(())
    }

    #[test]
    pub fn test_boolean_weight_optimization() -> crate::Result<()> {
        let mut schema_builder = Schema::builder();
        let text_field = schema_builder.add_text_field("text", TEXT);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut index_writer: IndexWriter = index.writer_for_tests()?;
        index_writer.add_document(doc!(text_field=>"hello"))?;
        index_writer.add_document(doc!(text_field=>"hello happy"))?;
        index_writer.commit()?;
        let searcher = index.reader()?.searcher();
        let term_match_all: Box<dyn Query> = Box::new(TermQuery::new(
            Term::from_field_text(text_field, "hello"),
            IndexRecordOption::Basic,
        ));
        let term_match_some: Box<dyn Query> = Box::new(TermQuery::new(
            Term::from_field_text(text_field, "happy"),
            IndexRecordOption::Basic,
        ));
        let term_match_none: Box<dyn Query> = Box::new(TermQuery::new(
            Term::from_field_text(text_field, "tax"),
            IndexRecordOption::Basic,
        ));
        {
            let query = BooleanQuery::from(vec![
                (Occur::Must, term_match_all.box_clone()),
                (Occur::Must, term_match_some.box_clone()),
            ]);
            let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
            let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
            assert!(scorer.is::<TermScorer>());
        }
        {
            let query = BooleanQuery::from(vec![
                (Occur::Must, term_match_all.box_clone()),
                (Occur::Must, term_match_some.box_clone()),
                (Occur::Must, term_match_none.box_clone()),
            ]);
            let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
            let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
            assert!(scorer.is::<EmptyScorer>());
        }
        {
            let query = BooleanQuery::from(vec![
                (Occur::Should, term_match_all.box_clone()),
                (Occur::Should, term_match_none.box_clone()),
            ]);
            let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
            let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
            assert!(scorer.is::<AllScorer>());
        }
        {
            let query = BooleanQuery::from(vec![
                (Occur::Should, term_match_some.box_clone()),
                (Occur::Should, term_match_none.box_clone()),
            ]);
            let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
            let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
            assert!(scorer.is::<TermScorer>());
        }
        Ok(())
    }
}

@@ -53,7 +53,7 @@ use crate::{Score, Term};
|
||||
/// // TermQuery "diary" and "girl" should be present and only one should be accounted in score
|
||||
/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
|
||||
/// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
|
||||
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
|
||||
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3).order_by_score())?;
|
||||
/// assert_eq!(documents[0].0, documents[1].0);
|
||||
/// assert_eq!(documents[1].0, documents[2].0);
|
||||
///
|
||||
@@ -62,7 +62,7 @@ use crate::{Score, Term};
 /// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
 /// let tie_breaker = 0.7;
 /// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
-/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
+/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3).order_by_score())?;
 /// assert_eq!(documents[1].0, documents[2].0);
 /// // For this test all terms brings the same score. So we can do easy math and assume that
 /// // `DisjunctionMaxQuery` with tie breakers score should be equal

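The "easy math" the doc comment leans on is the usual disjunction-max rule: the best sub-score counts fully and every other sub-score is scaled by the tie breaker. A hedged illustrative helper (not tantivy's internal code; assumes non-negative scores, as BM25 produces):

    fn dis_max_score(sub_scores: &[f32], tie_breaker: f32) -> f32 {
        // the best clause wins outright; the rest contribute scaled down
        let max = sub_scores.iter().cloned().fold(0.0f32, f32::max);
        let sum: f32 = sub_scores.iter().sum();
        max + tie_breaker * (sum - max)
    }
    // Two clauses with the same score s thus yield s * (1.0 + tie_breaker),
    // which is what lets the doc test predict the scores exactly.
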
@@ -127,7 +127,11 @@ impl Weight for ExistsWeight {
             .any(|col| matches!(col.column_index(), ColumnIndex::Full))
         {
             let all_scorer = AllScorer::new(max_doc);
-            return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
+            if boost != 1.0f32 {
+                return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
+            } else {
+                return Ok(Box::new(all_scorer));
+            }
         }

         // If we have a single dynamic column, use ExistsDocSet

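This hunk applies a small identity-boost specialization: a boost of 1.0 does not change any score, so wrapping AllScorer in BoostScorer would only add a per-document multiplication and an extra layer of dispatch. The guard generalizes to any scorer; a hedged sketch (the helper name is illustrative):

    fn maybe_boost<S: Scorer + 'static>(scorer: S, boost: Score) -> Box<dyn Scorer> {
        if boost == 1.0 {
            Box::new(scorer) // identity boost: skip the wrapper entirely
        } else {
            Box::new(BoostScorer::new(scorer, boost))
        }
    }
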
@@ -67,7 +67,7 @@ impl Automaton for DfaWrapper {
 /// {
 ///     let term = Term::from_field_text(title, "Diary");
 ///     let query = FuzzyTermQuery::new(term, 1, true);
-///     let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
+///     let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2).order_by_score(), Count)).unwrap();
 ///     assert_eq!(count, 2);
 ///     assert_eq!(top_docs.len(), 2);
 /// }
@@ -241,7 +241,8 @@ mod test {
         {
             let term = get_json_path_term("attributes.aa:japan")?;
             let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 1, "Expected only 1 document");
             assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
         }
@@ -252,7 +253,8 @@ mod test {
             let term = get_json_path_term("attributes.a:japon")?;

             let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 1, "Expected only 1 document");
             assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
         }
@@ -262,7 +264,8 @@ mod test {
             let term = get_json_path_term("attributes.a:jap")?;

             let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 0, "Expected no document");
         }

@@ -292,7 +295,8 @@ mod test {
         {
             let term = Term::from_field_text(country_field, "japon");
             let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 1, "Expected only 1 document");
             let (score, _) = top_docs[0];
             assert_nearly_equals!(1.0, score);
@@ -303,7 +307,8 @@ mod test {
             let term = Term::from_field_text(country_field, "jap");

             let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 0, "Expected no document");
         }

@@ -311,7 +316,8 @@ mod test {
         {
             let term = Term::from_field_text(country_field, "jap");
             let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
-            let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+            let top_docs =
+                searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
             assert_eq!(top_docs.len(), 1, "Expected only 1 document");
             let (score, _) = top_docs[0];
             assert_nearly_equals!(1.0, score);

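The assertions above follow from plain Levenshtein arithmetic: with a distance budget of 1 (transpositions counted as single edits), "japon" reaches "japan" through one substitution, while "jap" would need two insertions and only matches once prefix matching is enabled. A hedged recap of the two constructors exercised here:

    // whole-term fuzzy match: "japon" is one substitution away from "japan"
    let whole = FuzzyTermQuery::new(Term::from_field_text(country_field, "japon"), 1, true);
    // prefix fuzzy match: "jap" is an exact prefix of "japan", so distance 1 suffices
    let prefix = FuzzyTermQuery::new_prefix(Term::from_field_text(country_field, "jap"), 1, true);
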
@@ -267,7 +267,7 @@ mod tests {
             .with_boost_factor(1.0)
             .with_stop_words(vec!["old".to_string()])
             .with_document(DocAddress::new(0, 0));
-        let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
+        let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
         let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
         doc_ids.sort_unstable();

@@ -283,7 +283,7 @@ mod tests {
             .with_max_word_length(5)
             .with_boost_factor(1.0)
             .with_document(DocAddress::new(0, 4));
-        let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
+        let top_docs = searcher.search(&query, &TopDocs::with_limit(5).order_by_score())?;
         let mut doc_ids: Vec<_> = top_docs.iter().map(|item| item.1.doc_id).collect();
         doc_ids.sort_unstable();

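For context, the builder chain visible in these hunks is how a more-like-this query is assembled: term filters such as stop words and word-length bounds prune the vocabulary taken from the seed document before it is turned into a similarity query. A hedged sketch of the shape, assuming the builder API exactly as it appears in these call sites:

    let query = MoreLikeThisQuery::builder()
        .with_boost_factor(1.0)
        .with_stop_words(vec!["old".to_string()])
        // seed the query with the terms of an already-indexed document
        .with_document(DocAddress::new(0, 0));
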
@@ -266,8 +266,9 @@ mod tests {
     use super::RangeQuery;
     use crate::collector::{Count, TopDocs};
     use crate::indexer::NoMergePolicy;
+    use crate::query::range_query::fast_field_range_doc_set::RangeDocSet;
     use crate::query::range_query::range_query::InvertedIndexRangeQuery;
-    use crate::query::QueryParser;
+    use crate::query::{AllScorer, ConstScorer, EmptyScorer, EnableScoring, Query, QueryParser};
     use crate::schema::{
         Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
     };
@@ -495,7 +496,7 @@ mod tests {
         let searcher = reader.searcher();
         let query_parser = QueryParser::for_index(&index, vec![title]);
         let query = query_parser.parse_query("hemoglobin AND year:[1970 TO 1990]")?;
-        let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?;
+        let top_docs = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
         assert_eq!(top_docs.len(), 1);
         Ok(())
     }
@@ -549,7 +550,7 @@ mod tests {

         let get_num_hits = |query| {
             let (_top_docs, count) = searcher
-                .search(&query, &(TopDocs::with_limit(10), Count))
+                .search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
                 .unwrap();
             count
         };
@@ -660,4 +661,46 @@ mod tests {
             0
         );
     }
+
+    #[test]
+    fn test_range_query_simplified() {
+        // This test checks that if the targeted column values are entirely
+        // within the range, and the column is full, we end up with an AllScorer.
+        let mut schema_builder = Schema::builder();
+        let u64_field = schema_builder.add_u64_field("u64_field", FAST);
+        let schema = schema_builder.build();
+        let index = Index::create_in_ram(schema.clone());
+        let mut index_writer = index.writer_for_tests().unwrap();
+        index_writer.add_document(doc!(u64_field=> 2u64)).unwrap();
+        index_writer.add_document(doc!(u64_field=> 4u64)).unwrap();
+        index_writer.commit().unwrap();
+        let reader = index.reader().unwrap();
+        let searcher = reader.searcher();
+        assert_eq!(searcher.segment_readers().len(), 1);
+        let make_term = |value: u64| Term::from_field_u64(u64_field, value);
+        let make_scorer = move |lower_bound: Bound<u64>, upper_bound: Bound<u64>| {
+            let lower_bound_term = lower_bound.map(make_term);
+            let upper_bound_term = upper_bound.map(make_term);
+            let range_query = RangeQuery::new(lower_bound_term, upper_bound_term);
+            let range_weight = range_query
+                .weight(EnableScoring::disabled_from_schema(&schema))
+                .unwrap();
+            let range_scorer = range_weight
+                .scorer(&searcher.segment_readers()[0], 1.0f32)
+                .unwrap();
+            range_scorer
+        };
+        let range_scorer = make_scorer(Bound::Included(1), Bound::Included(4));
+        assert!(range_scorer.is::<AllScorer>());
+        let range_scorer = make_scorer(Bound::Included(0), Bound::Included(2));
+        assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
+        let range_scorer = make_scorer(Bound::Included(3), Bound::Included(10));
+        assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
+        let range_scorer = make_scorer(Bound::Included(10), Bound::Included(12));
+        assert!(range_scorer.is::<ConstScorer<RangeDocSet<u64>>>());
+        let range_scorer = make_scorer(Bound::Included(0), Bound::Included(1));
+        assert!(range_scorer.is::<EmptyScorer>());
+        let range_scorer = make_scorer(Bound::Included(0), Bound::Excluded(2));
+        assert!(range_scorer.is::<EmptyScorer>());
+    }
 }
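Case by case, the assertions pin down the specialization: a range covering the whole value set {2, 4} of a full column yields AllScorer, a partially overlapping range (and, notably, the disjoint high range [10, 12]) falls back to ConstScorer<RangeDocSet<u64>>, and a range strictly below the smallest value short-circuits to EmptyScorer. A hedged usage sketch reusing the test's bindings (Count is crate::collector::Count; whether an unbounded range takes the AllScorer path is not asserted above, only that the result is correct):

    // A range swallowing every column value: both documents must match.
    let query = RangeQuery::new(
        Bound::Included(Term::from_field_u64(u64_field, 0)),
        Bound::Unbounded,
    );
    assert_eq!(searcher.search(&query, &Count)?, 2);
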
Some files were not shown because too many files have changed in this diff.