Compare commits


7 Commits

Author SHA1 Message Date
PSeitz-dd
65b5a1a306 one collector per agg request instead per bucket (#2759)
* improve bench

* add more tests for new collection type

* one collector per agg request instead per bucket

In this refactoring, a collector knows which bucket of its parent its
data belongs to. This makes it possible to convert the previous
approach of one collector per bucket into one collector per request
(see the sketch below).

low card bucket optimization
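A minimal sketch of the per-request collection idea, using hypothetical types rather than the actual tantivy collector API: the collector is created once per aggregation request, and its per-bucket state lives in a flat table indexed by the parent bucket id that is passed into each collect call.

    // Hypothetical sketch, not the real tantivy API: one collector instance per
    // request; per-bucket state is kept in a table indexed by the parent bucket id.
    type BucketId = u32;

    #[derive(Default)]
    struct DocCountCollector {
        counts: Vec<u64>, // one entry per parent bucket
    }

    impl DocCountCollector {
        fn collect_block(&mut self, parent_bucket_id: BucketId, docs: &[u32]) {
            let idx = parent_bucket_id as usize;
            if self.counts.len() <= idx {
                self.counts.resize(idx + 1, 0);
            }
            self.counts[idx] += docs.len() as u64;
        }
    }

    fn main() {
        let mut collector = DocCountCollector::default();
        collector.collect_block(0, &[1, 2, 3]); // docs routed to parent bucket 0
        collector.collect_block(1, &[4]);       // docs routed to parent bucket 1
        assert_eq!(collector.counts, vec![3, 1]);
    }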

* reduce dynamic dispatch, faster term agg

* use radix map, fix prepare_max_bucket

use paged term map in term agg
use special no sub agg term map impl

* specialize columntype in stats

* remove stacktrace bloat, use &mut helper

increase cache to 2048

* cleanup

remove clone
move data in term req, single doc opt for stats

* add comment

* share column block accessor

* simplify fetch block in column_block_accessor

* split subaggcache into two trait impls

* move partitions to heap

* fix name, add comment

---------

Co-authored-by: Pascal Seitz <pascal.seitz@gmail.com>
2026-01-06 11:50:55 +01:00
ChangRui-Ryan
db2ecc6057 fix Column.first method parameter type (#2792) 2026-01-05 10:03:01 +01:00
Paul Masurel
77505c3d03 Making stemming optional. (#2791)
Fixed code and CI to run with no default features.
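A minimal sketch of the feature gating this implies on the Rust side, with hypothetical function names; the corresponding Cargo.toml change (the rust-stemmers dependency made optional behind a `stemmer` feature) appears in the diff below.

    // Hypothetical sketch: code that needs the stemmer only compiles when the
    // "stemmer" cargo feature (and thus the optional rust-stemmers dependency)
    // is enabled.
    #[cfg(feature = "stemmer")]
    fn default_token_filters() -> Vec<&'static str> {
        vec!["lowercase", "stop_words", "stemmer"]
    }

    #[cfg(not(feature = "stemmer"))]
    fn default_token_filters() -> Vec<&'static str> {
        vec!["lowercase", "stop_words"]
    }

    fn main() {
        println!("default filters: {:?}", default_token_filters());
    }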

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2026-01-02 12:40:42 +01:00
PSeitz
735c588f4f fix union performance regression (#2790)
* add inlines

* fix union performance regression

Removing the unwrap from the hot path generates better assembly (see the sketch below).

closes #2788
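A minimal sketch of the pattern, not the actual union scorer code: an unwrap() inside the per-document loop keeps a panic branch in the generated code, while resolving the Option once outside the loop lets the compiler emit a tighter loop body.

    // Before (hypothetical): unwrap() executes on every iteration of the hot loop.
    fn count_matches(first_doc: Option<u32>, docs: &[u32]) -> usize {
        docs.iter()
            .filter(|&&doc| doc >= first_doc.unwrap())
            .count()
    }

    // After: the Option is handled once, outside the loop, so the loop body has
    // no panic path.
    fn count_matches_hoisted(first_doc: Option<u32>, docs: &[u32]) -> usize {
        let Some(first) = first_doc else { return 0 };
        docs.iter().filter(|&&doc| doc >= first).count()
    }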
2026-01-02 12:06:51 +01:00
PSeitz
242a1531bf fix flaky test (#2784)
Signed-off-by: Pascal Seitz <pascal.seitz@gmail.com>
2026-01-02 11:30:51 +01:00
trinity-1686a
6443b63177 document 1bit hole and some queries supporting running with just fastfield (#2779)
* add small doc on some queries using fast field when not indexed

* document 1 unused bit in skiplist
2026-01-02 10:32:37 +01:00
Stu Hood
4987495ee4 Add an erased SortKeyComputer to sort on types which are not known until runtime (#2770)
* Remove PartialOrd bound on compared values.

* Fix declared `SortKey` type of `impl<..> SortKeyComputer for (HeadSortKeyComputer, TailSortKeyComputer)`

* Add a SortByOwnedValue implementation to provide a type-erased column.

* Add support for comparing mismatched `OwnedValue` types.

* Support JSON columns.

* Refer to https://github.com/quickwit-oss/tantivy/issues/2776

* Rename to `SortByErasedType`.

* Comment on transitivity.

Co-authored-by: Paul Masurel <paul@quickwit.io>

* Fix clippy warnings in new code.
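Regarding the "mismatched OwnedValue types" and transitivity notes above, a minimal sketch of one way to keep a cross-type comparison total and transitive, using a hypothetical runtime-typed value enum rather than tantivy's actual OwnedValue or SortKeyComputer types: values of the same type compare naturally, and mismatched types fall back to a fixed rank per type, making the overall order lexicographic on (type rank, value).

    use std::cmp::Ordering;

    // Hypothetical stand-in for a runtime-typed column value.
    #[derive(Debug, Clone, PartialEq)]
    enum DynValue {
        I64(i64),
        F64(f64),
        Str(String),
    }

    impl DynValue {
        fn type_rank(&self) -> u8 {
            match self {
                DynValue::I64(_) => 0,
                DynValue::F64(_) => 1,
                DynValue::Str(_) => 2,
            }
        }

        fn compare(&self, other: &Self) -> Ordering {
            match (self, other) {
                (DynValue::I64(a), DynValue::I64(b)) => a.cmp(b),
                (DynValue::F64(a), DynValue::F64(b)) => a.total_cmp(b),
                (DynValue::Str(a), DynValue::Str(b)) => a.cmp(b),
                // Mismatched types: order by type rank so the relation stays
                // transitive (every I64 sorts before every F64, etc.).
                _ => self.type_rank().cmp(&other.type_rank()),
            }
        }
    }

    fn main() {
        let mut values = vec![
            DynValue::Str("b".into()),
            DynValue::F64(1.5),
            DynValue::I64(3),
            DynValue::I64(1),
        ];
        values.sort_by(|a, b| a.compare(b));
        assert_eq!(values[0], DynValue::I64(1));
    }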

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2026-01-02 10:28:47 +01:00
92 changed files with 3157 additions and 2061 deletions


@@ -39,11 +39,11 @@ jobs:
       - name: Check Formatting
         run: cargo +nightly fmt --all -- --check
       - name: Check Stable Compilation
         run: cargo build --all-features
       - name: Check Bench Compilation
         run: cargo +nightly bench --no-run --profile=dev --all-features
@@ -59,10 +59,10 @@ jobs:
     strategy:
       matrix:
-        features: [
-          { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints" },
-          { label: "quickwit", flags: "mmap,quickwit,failpoints" }
-        ]
+        features:
+          - { label: "all", flags: "mmap,stopwords,lz4-compression,zstd-compression,failpoints,stemmer" }
+          - { label: "quickwit", flags: "mmap,quickwit,failpoints" }
+          - { label: "none", flags: "" }
     name: test-${{ matrix.features.label}}
@@ -80,7 +80,21 @@ jobs:
       - uses: Swatinem/rust-cache@v2
       - name: Run tests
-        run: cargo +stable nextest run --features ${{ matrix.features.flags }} --verbose --workspace
+        run: |
+          # if matrix.feature.flags is empty then run on --lib to avoid compiling examples
+          # (as most of them rely on mmap) otherwise run all
+          if [ -z "${{ matrix.features.flags }}" ]; then
+            cargo +stable nextest run --lib --no-default-features --verbose --workspace
+          else
+            cargo +stable nextest run --features ${{ matrix.features.flags }} --no-default-features --verbose --workspace
+          fi
       - name: Run doctests
-        run: cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
+        run: |
+          # if matrix.feature.flags is empty then run on --lib to avoid compiling examples
+          # (as most of them rely on mmap) otherwise run all
+          if [ -z "${{ matrix.features.flags }}" ]; then
+            echo "no doctest for no feature flag"
+          else
+            cargo +stable test --doc --features ${{ matrix.features.flags }} --verbose --workspace
+          fi


@@ -37,7 +37,7 @@ fs4 = { version = "0.13.1", optional = true }
 levenshtein_automata = "0.2.1"
 uuid = { version = "1.0.0", features = ["v4", "serde"] }
 crossbeam-channel = "0.5.4"
-rust-stemmers = "1.2.0"
+rust-stemmers = { version = "1.2.0", optional = true }
 downcast-rs = "2.0.1"
 bitpacking = { version = "0.9.2", default-features = false, features = [
     "bitpacker4x",
@@ -113,7 +113,8 @@ debug-assertions = true
 overflow-checks = true
 [features]
-default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression"]
+default = ["mmap", "stopwords", "lz4-compression", "columnar-zstd-compression", "stemmer"]
+stemmer = ["rust-stemmers"]
 mmap = ["fs4", "tempfile", "memmap2"]
 stopwords = []


@@ -54,33 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
     register!(group, stats_f64);
     register!(group, extendedstats_f64);
     register!(group, percentiles_f64);
-    register!(group, terms_few);
+    register!(group, terms_7);
     register!(group, terms_all_unique);
-    register!(group, terms_many);
+    register!(group, terms_150_000);
     register!(group, terms_many_top_1000);
     register!(group, terms_many_order_by_term);
     register!(group, terms_many_with_top_hits);
     register!(group, terms_all_unique_with_avg_sub_agg);
     register!(group, terms_many_with_avg_sub_agg);
-    register!(group, terms_few_with_avg_sub_agg);
     register!(group, terms_status_with_avg_sub_agg);
-    register!(group, terms_status);
-    register!(group, terms_few_with_histogram);
     register!(group, terms_status_with_histogram);
+    register!(group, terms_zipf_1000);
+    register!(group, terms_zipf_1000_with_histogram);
+    register!(group, terms_zipf_1000_with_avg_sub_agg);
     register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
     register!(group, cardinality_agg);
-    register!(group, terms_few_with_cardinality_agg);
+    register!(group, terms_status_with_cardinality_agg);
     register!(group, range_agg);
     register!(group, range_agg_with_avg_sub_agg);
-    register!(group, range_agg_with_term_agg_few);
+    register!(group, range_agg_with_term_agg_status);
     register!(group, range_agg_with_term_agg_many);
     register!(group, histogram);
     register!(group, histogram_hard_bounds);
     register!(group, histogram_with_avg_sub_agg);
-    register!(group, histogram_with_term_agg_few);
+    register!(group, histogram_with_term_agg_status);
     register!(group, avg_and_range_with_avg_sub_agg);
     // Filter aggregation benchmarks
@@ -159,10 +159,10 @@ fn cardinality_agg(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_few_with_cardinality_agg(index: &Index) { fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
"terms": { "field": "text_few_terms" }, "terms": { "field": "text_few_terms_status" },
"aggs": { "aggs": {
"cardinality": { "cardinality": {
"cardinality": { "cardinality": {
@@ -175,13 +175,7 @@ fn terms_few_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_few(index: &Index) { fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_status(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } }, "my_texts": { "terms": { "field": "text_few_terms_status" } },
}); });
@@ -194,7 +188,7 @@ fn terms_all_unique(index: &Index) {
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_many(index: &Index) { fn terms_150_000(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } }, "my_texts": { "terms": { "field": "text_many_terms" } },
}); });
@@ -253,17 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_few_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) { fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
@@ -276,17 +259,18 @@ fn terms_status_with_histogram(index: &Index) {
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_few_with_avg_sub_agg(index: &Index) { fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
"terms": { "field": "text_few_terms" }, "terms": { "field": "text_1000_terms_zipf" },
"aggs": { "aggs": {
"average_f64": { "avg": { "field": "score_f64" } } "histo": {"histogram": { "field": "score_f64", "interval": 10 }}
} }
}, }
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_status_with_avg_sub_agg(index: &Index) { fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
@@ -299,6 +283,25 @@ fn terms_status_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
@@ -354,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn range_agg_with_term_agg_few(index: &Index) { fn range_agg_with_term_agg_status(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"rangef64": { "rangef64": {
"range": { "range": {
@@ -369,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
] ]
}, },
"aggs": { "aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } }, "my_texts": { "terms": { "field": "text_few_terms_status" } },
} }
}, },
}); });
@@ -425,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn histogram_with_term_agg_few(index: &Index) { fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"rangef64": { "rangef64": {
"histogram": { "field": "score_f64", "interval": 10 }, "histogram": { "field": "score_f64", "interval": 10 },
"aggs": { "aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } } "my_texts": { "terms": { "field": "text_few_terms_status" } }
} }
} }
}); });
@@ -475,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
} }
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> { fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// create dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default() let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options( .set_indexing_options(
@@ -486,24 +496,44 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let text_field_all_unique_terms = let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST); schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST); let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status = let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST); schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast(); let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone()); let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone()); let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype); let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?; // use tmp dir
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"]; let index = if reuse_index {
// Approximate production log proportions: INFO dominant, WARN and DEBUG occasional, ERROR rare. Index::create_in_dir("agg_bench", schema_builder.build())?
let log_level_distribution = WeightedIndex::new([80u32, 3, 12, 5]).unwrap(); } else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap(); let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000) let many_terms_data = (0..150_000)
.map(|num| format!("author{num}")) .map(|num| format!("author{num}"))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{ {
let mut rng = StdRng::from_seed([1u8; 32]); let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?; let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -513,8 +543,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?; index_writer.add_document(doc!())?;
} }
if cardinality == Cardinality::Multivalued { if cardinality == Cardinality::Multivalued {
let log_level_sample_a = few_terms_data[log_level_distribution.sample(&mut rng)]; let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = few_terms_data[log_level_distribution.sample(&mut rng)]; let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!( index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}), json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}), json_field => json!({"mixed_type": 10.0}),
@@ -524,10 +558,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
text_field_all_unique_terms => "coolo", text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool", text_field_many_terms => "cool",
text_field_many_terms => "cool", text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a, text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b, text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64, score_field => 1u64,
score_field => 1u64, score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng), score_field_f64 => lg_norm.sample(&mut rng),
@@ -554,8 +588,8 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json_field => json, json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()), text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(), text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(), text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_few_terms_status => few_terms_data[log_level_distribution.sample(&mut rng)], text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64, score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng), score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64, score_field_i64 => val as i64,
@@ -607,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } }, "avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } }, "stats_score": { "stats": { "field": "score_f64" } },
"terms_text": { "terms_text": {
"terms": { "field": "text_few_terms" } "terms": { "field": "text_few_terms_status" }
} }
} }
} }
@@ -623,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } }, "avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } }, "stats_score": { "stats": { "field": "score_f64" } },
"terms_text": { "terms_text": {
"terms": { "field": "text_few_terms" } "terms": { "field": "text_few_terms_status" }
} }
} }
} }


@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
         }
     }
     #[inline]
-    pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
+    pub fn fetch_block_with_missing(
+        &mut self,
+        docs: &[u32],
+        accessor: &Column<T>,
+        missing: Option<T>,
+    ) {
         self.fetch_block(docs, accessor);
         // no missing values
         if accessor.index.get_cardinality().is_full() {
             return;
         }
+        let Some(missing) = missing else {
+            return;
+        };
         // We can compare docid_cache length with docs to find missing docs
         // For multi value columns we can't rely on the length and always need to scan


@@ -85,8 +85,8 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
     }
     #[inline]
-    pub fn first(&self, row_id: RowId) -> Option<T> {
-        self.values_for_doc(row_id).next()
+    pub fn first(&self, doc_id: DocId) -> Option<T> {
+        self.values_for_doc(doc_id).next()
     }
     /// Load the first value for each docid in the provided slice.


@@ -60,7 +60,7 @@ fn test_dataframe_writer_bool() {
     let DynamicColumn::Bool(bool_col) = dyn_bool_col else {
         panic!();
     };
-    let vals: Vec<Option<bool>> = (0..5).map(|row_id| bool_col.first(row_id)).collect();
+    let vals: Vec<Option<bool>> = (0..5).map(|doc_id| bool_col.first(doc_id)).collect();
     assert_eq!(&vals, &[None, Some(false), None, Some(true), None,]);
 }
@@ -108,7 +108,7 @@ fn test_dataframe_writer_ip_addr() {
     let DynamicColumn::IpAddr(ip_col) = dyn_bool_col else {
         panic!();
     };
-    let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|row_id| ip_col.first(row_id)).collect();
+    let vals: Vec<Option<Ipv6Addr>> = (0..5).map(|doc_id| ip_col.first(doc_id)).collect();
     assert_eq!(
         &vals,
         &[
@@ -169,7 +169,7 @@ fn test_dictionary_encoded_str() {
     let DynamicColumn::Str(str_col) = col_handles[0].open().unwrap() else {
         panic!();
     };
-    let index: Vec<Option<u64>> = (0..5).map(|row_id| str_col.ords().first(row_id)).collect();
+    let index: Vec<Option<u64>> = (0..5).map(|doc_id| str_col.ords().first(doc_id)).collect();
     assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
     assert_eq!(str_col.num_rows(), 5);
     let mut term_buffer = String::new();
@@ -204,7 +204,7 @@ fn test_dictionary_encoded_bytes() {
         panic!();
     };
     let index: Vec<Option<u64>> = (0..5)
-        .map(|row_id| bytes_col.ords().first(row_id))
+        .map(|doc_id| bytes_col.ords().first(doc_id))
         .collect();
     assert_eq!(index, &[None, Some(0), None, Some(2), Some(1)]);
     assert_eq!(bytes_col.num_rows(), 5);


@@ -181,6 +181,14 @@ pub struct BitSet {
     len: u64,
     max_value: u32,
 }
+impl std::fmt::Debug for BitSet {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("BitSet")
+            .field("len", &self.len)
+            .field("max_value", &self.max_value)
+            .finish()
+    }
+}
 fn num_buckets(max_val: u32) -> u32 {
     max_val.div_ceil(64u32)


@@ -1,4 +1,4 @@
use columnar::{Column, ColumnType, StrColumn}; use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use common::BitSet; use common::BitSet;
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
use serde::Serialize; use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
}; };
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations}; use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{ use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector, HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation, RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal, TermsAggregationInternal,
}; };
use crate::aggregation::metric::{ use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation, build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector, MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData, SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector, TopHitsSegmentCollector,
}; };
use crate::aggregation::segment_agg_result::{ use crate::aggregation::segment_agg_result::{
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type. /// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx, pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams, pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
} }
impl AggregationsSegmentCtx { impl AggregationsSegmentCtx {
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
.as_deref() .as_deref()
.expect("range_req_data slot is empty (taken)") .expect("range_req_data slot is empty (taken)")
} }
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ---------- // ---------- mutable getters ----------
#[inline] #[inline]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData { pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
self.per_request.term_req_data[idx] &mut self.per_request.stats_metric_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
} }
#[inline] #[inline]
pub(crate) fn get_cardinality_req_data_mut( pub(crate) fn get_cardinality_req_data_mut(
&mut self, &mut self,
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData { ) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx] &mut self.per_request.cardinality_req_data[idx]
} }
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline] #[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData { pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx] self.per_request.histogram_req_data[idx]
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ---------- // ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`. /// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline] #[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> { pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {
/// Convert the aggregation tree into a serializable struct representation. /// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }. /// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> { pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode { fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> = let mut children: Vec<AggTreeViewNode> =
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root( pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx, req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> { ) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone()) build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
} }
pub(crate) fn build_segment_agg_collectors( pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx, req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode], nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> { ) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new(); let mut collectors = Vec::new();
for node in nodes.iter() { for node in nodes.iter() {
@@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req( Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type, req_data.column_type,
node.idx_in_req_data, node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
))) )))
} }
AggKind::StatsKind(stats_type) => { AggKind::StatsKind(stats_type) => {
@@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count | StatsType::Count
| StatsType::Max | StatsType::Max
| StatsType::Min | StatsType::Min
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req( | StatsType::Stats => build_segment_stats_collector(req_data),
node.idx_in_req_data, StatsType::ExtendedStats(sigma) => Ok(Box::new(
))), SegmentExtendedStatsCollector::from_req(req_data, sigma),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
)), )),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
}
} }
} }
AggKind::TopHits => { AggKind::TopHits => {
@@ -428,12 +415,8 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate( AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node, req, node,
)?)), )?)),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate( AggKind::Range => Ok(build_segment_range_collector(req, node)?),
req, node, AggKind::Filter => build_segment_filter_collector(req, node),
)?)),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
} }
} }
@@ -493,6 +476,7 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx { let mut data = AggregationsSegmentCtx {
per_request: Default::default(), per_request: Default::default(),
context, context,
column_block_accessor: ColumnBlockAccessor::default(),
}; };
for (name, agg) in aggs.iter() { for (name, agg) in aggs.iter() {
@@ -521,9 +505,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData { let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor, accessor,
field_type, field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
req: range_req.clone(), req: range_req.clone(),
is_top_level,
}); });
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode { Ok(vec![AggRefNode {
@@ -541,9 +525,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor, accessor,
field_type, field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(), req: histo_req.clone(),
is_date_histogram: false, is_date_histogram: false,
bounds: HistogramBounds { bounds: HistogramBounds {
@@ -568,9 +550,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData { let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor, accessor,
field_type, field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req, req: histo_req,
is_date_histogram: true, is_date_histogram: true,
bounds: HistogramBounds { bounds: HistogramBounds {
@@ -650,7 +630,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor, accessor,
field_type, field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
collecting_for, collecting_for,
missing: *missing, missing: *missing,
@@ -678,7 +657,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData { let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor, accessor,
field_type, field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
collecting_for: StatsType::Percentiles, collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing, missing: percentiles_req.missing,
@@ -753,6 +731,7 @@ fn build_nodes(
segment_reader: reader.clone(), segment_reader: reader.clone(),
evaluator, evaluator,
matching_docs_buffer, matching_docs_buffer,
is_top_level,
}); });
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?; let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode { Ok(vec![AggRefNode {
@@ -895,7 +874,7 @@ fn build_terms_or_cardinality_nodes(
}); });
} }
// Add one node per accessor to mirror previous behavior and allow per-type missing handling. // Add one node per accessor
for (accessor, column_type) in column_and_types { for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg { let missing_value_for_accessor = if use_special_missing_agg {
None None
@@ -926,11 +905,8 @@ fn build_terms_or_cardinality_nodes(
column_type, column_type,
str_dict_column: str_dict_column.clone(), str_dict_column: str_dict_column.clone(),
missing_value_for_accessor, missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req), req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(), sug_aggregations: sub_aggs.clone(),
allowed_term_ids, allowed_term_ids,
is_top_level, is_top_level,
@@ -943,7 +919,6 @@ fn build_terms_or_cardinality_nodes(
column_type, column_type,
str_dict_column: str_dict_column.clone(), str_dict_column: str_dict_column.clone(),
missing_value_for_accessor, missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(), name: agg_name.to_string(),
req: req.clone(), req: req.clone(),
}); });


@@ -2,15 +2,441 @@ use serde_json::Value;
 use crate::aggregation::agg_req::{Aggregation, Aggregations};
 use crate::aggregation::agg_result::AggregationResults;
-use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
 use crate::aggregation::collector::AggregationCollector;
 use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
 use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
 use crate::aggregation::DistributedAggregationCollector;
+use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
 use crate::query::{AllQuery, TermQuery};
 use crate::schema::{IndexRecordOption, Schema, FAST};
 use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 5 buckets
// Child: coarse range with 3 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 0, "from": 20.0}
]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_range": {"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
{"key": "20-*", "doc_count": 2, "from": 20.0}
]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: range with 4 buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0}
]
}
}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 1, "to": 3.0},
{"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_range": {
"buckets": [
{"key": "*-3", "doc_count": 0, "to": 3.0},
{"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
{"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
{"key": "20-*", "doc_count": 1, "from": 20.0}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several ranges
// Child: histogram with large interval (single bucket per parent)
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 100.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
assert_eq!(
res["parent_range"]["buckets"],
json!([
{"key": "*-3", "doc_count": 1, "to": 3.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
},
{"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
},
{"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
},
{"key": "20-*", "doc_count": 2, "from": 20.0,
"child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
}
])
);
// Case B: child has more buckets than parent
// Parent: terms on text -> 2 buckets
// Child: histogram with small interval -> multiple buckets including empties
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_hist": {"histogram": {"field": "score", "interval": 10.0}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_terms"],
json!({
"buckets": [
{
"key": "cool",
"doc_count": 7,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 4},
{"key": 10.0, "doc_count": 2},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
},
{
"key": "nohit",
"doc_count": 2,
"child_hist": {
"buckets": [
{"key": 0.0, "doc_count": 1},
{"key": 10.0, "doc_count": 0},
{"key": 20.0, "doc_count": 0},
{"key": 30.0, "doc_count": 0},
{"key": 40.0, "doc_count": 1}
]
}
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
})
);
Ok(())
}
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with several buckets
// Child: date_histogram with 30d -> single bucket per parent
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 11.0},
{"from": 11.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
let buckets = res["parent_range"]["buckets"].as_array().unwrap();
// Verify each parent bucket has exactly one child date bucket with matching doc_count
for bucket in buckets {
let parent_count = bucket["doc_count"].as_u64().unwrap();
let child_buckets = bucket["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(child_buckets.len(), 1);
assert_eq!(child_buckets[0]["doc_count"], parent_count);
}
// Case B: child has more buckets than parent
// Parent: terms on text (2 buckets)
// Child: date_histogram with 1d -> multiple buckets
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_terms": {
"terms": {"field": "text"},
"aggs": {
"child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
let buckets = res["parent_terms"]["buckets"].as_array().unwrap();
// cool bucket
assert_eq!(buckets[0]["key"], "cool");
let cool_buckets = buckets[0]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(cool_buckets.len(), 3);
assert_eq!(cool_buckets[0]["doc_count"], 1); // day 0
assert_eq!(cool_buckets[1]["doc_count"], 4); // day 1
assert_eq!(cool_buckets[2]["doc_count"], 2); // day 2
// nohit bucket
assert_eq!(buckets[1]["key"], "nohit");
let nohit_buckets = buckets[1]["child_date_hist"]["buckets"].as_array().unwrap();
assert_eq!(nohit_buckets.len(), 2);
assert_eq!(nohit_buckets[0]["doc_count"], 1); // day 1
assert_eq!(nohit_buckets[1]["doc_count"], 1); // day 2
Ok(())
}
 fn get_avg_req(field_name: &str) -> Aggregation {
     serde_json::from_value(json!({
         "avg": {
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
 }
 // *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
+// Note: The flushing part of these tests is outdated, since the buffering changed after
+// converting the collection into one collector per request instead of per bucket.
+//
+// However, they are useful as they test complex aggregation requests.
 fn test_aggregation_flushing(
     merge_segments: bool,
     use_distributed_collector: bool,
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(
     let reader = index.reader()?;
-    assert_eq!(DOC_BLOCK_SIZE, 64);
-    // In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
+    assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
+    // In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
+    // block.
     //
     // Build a request so that on the first level we have one full cache, which is then flushed.
     // The same cache should have some residue docs at the end, which are flushed (Range 0-70)


@@ -6,10 +6,14 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
}; };
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
}; };
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector}; use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::docset::DocSet; use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser}; use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema; use crate::schema::Schema;
@@ -404,15 +408,18 @@ pub struct FilterAggReqData {
pub evaluator: DocumentQueryEvaluator, pub evaluator: DocumentQueryEvaluator,
/// Reusable buffer for matching documents to minimize allocations during collection /// Reusable buffer for matching documents to minimize allocations during collection
pub matching_docs_buffer: Vec<DocId>, pub matching_docs_buffer: Vec<DocId>,
/// True if this filter aggregation is at the top level of the aggregation tree (not nested).
pub is_top_level: bool,
} }
impl FilterAggReqData { impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize { pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity // Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len() self.name.len()
+ std::mem::size_of::<SegmentReader>() + std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes) + self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>() + self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
} }
} }
@@ -446,7 +453,7 @@ impl DocumentQueryEvaluator {
let weight = query.weight(EnableScoring::disabled_from_schema(&schema))?; let weight = query.weight(EnableScoring::disabled_from_schema(&schema))?;
// Get a scorer that iterates over matching documents // Get a scorer that iterates over matching documents
let mut scorer = weight.scorer(segment_reader, 1.0, 0)?; let mut scorer = weight.scorer(segment_reader, 1.0)?;
// Create a BitSet to hold all matching documents // Create a BitSet to hold all matching documents
let mut bitset = BitSet::with_max_value(max_doc); let mut bitset = BitSet::with_max_value(max_doc);
@@ -489,17 +496,24 @@ impl Debug for DocumentQueryEvaluator {
} }
} }
/// Segment collector for filter aggregation #[derive(Debug, Clone, PartialEq, Copy)]
pub struct SegmentFilterCollector { struct DocCount {
/// Document count in this bucket
doc_count: u64, doc_count: u64,
bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors /// Sub-aggregation collectors
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>, sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData) /// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize, accessor_idx: usize,
} }
impl SegmentFilterCollector { impl<C: SubAggCache> SegmentFilterCollector<C> {
/// Create a new filter segment collector following the new agg_data pattern /// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate( pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx, req: &mut AggregationsSegmentCtx,
@@ -511,47 +525,75 @@ impl SegmentFilterCollector {
} else { } else {
None None
}; };
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector { Ok(SegmentFilterCollector {
doc_count: 0, parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector, sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data, accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
}) })
} }
} }
impl Debug for SegmentFilterCollector { pub(crate) fn build_segment_filter_collector(
req: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let is_top_level = req.per_request.filter_req_data[node.idx_in_req_data]
.as_ref()
.expect("filter_req_data slot is empty")
.is_top_level;
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector") f.debug_struct("SegmentFilterCollector")
.field("doc_count", &self.doc_count) .field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some()) .field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx) .field("accessor_idx", &self.accessor_idx)
.finish() .finish()
} }
} }
impl CollectorClone for SegmentFilterCollector { impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default(); let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = self.sub_aggregations { if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?; sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
} }
// Create the filter bucket result // Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter { let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: self.doc_count, doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
sub_aggregations: sub_results, sub_aggregations: sub_results,
}; };
@@ -570,32 +612,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
Ok(()) Ok(())
} }
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { fn collect(
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self, &mut self,
docs: &[DocId], parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
if docs.is_empty() { if docs.is_empty() {
return Ok(()); return Ok(());
} }
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations // Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx); let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -604,18 +631,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
req.evaluator req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer); .filter_batch(docs, &mut req.matching_docs_buffer);
self.doc_count += req.matching_docs_buffer.len() as u64; bucket.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches // Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() { if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations { if let Some(sub_aggs) = &mut self.sub_aggregations {
// Use collect_block for better sub-aggregation performance for &doc_id in &req.matching_docs_buffer {
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?; sub_aggs.push(bucket.bucket_id, doc_id);
}
} }
} }
// Put the request data back // Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req); agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(()) Ok(())
} }
@@ -626,6 +659,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
} }
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.parent_buckets.push(DocCount {
doc_count: 0,
bucket_id,
});
}
Ok(())
}
} }
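The `prepare_max_bucket` pattern above captures the core of this refactoring: instead of cloning one collector per bucket, a single collector keeps per-parent-bucket state in a `Vec` indexed by the parent's bucket id, grows it lazily, and hands each new slot its own flattened bucket id for sub-aggregations. A minimal sketch of the idea, using toy types rather than the tantivy ones:

#[derive(Debug)]
struct Slot {
    doc_count: u64,
    bucket_id: u32, // flattened id handed down to sub-aggregations
}

// Grow the per-parent-bucket state up to `max_bucket`, assigning fresh bucket ids.
fn prepare_max_bucket(buckets: &mut Vec<Slot>, next_id: &mut u32, max_bucket: u32) {
    while buckets.len() <= max_bucket as usize {
        buckets.push(Slot { doc_count: 0, bucket_id: *next_id });
        *next_id += 1;
    }
}

fn main() {
    let mut next_id = 0u32;
    let mut buckets: Vec<Slot> = Vec::new();
    prepare_max_bucket(&mut buckets, &mut next_id, 2); // parent announced bucket ids 0..=2
    buckets[2].doc_count += 1; // a matching doc landed in parent bucket 2
    println!("{buckets:?}");
}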
/// Intermediate result for filter aggregation /// Intermediate result for filter aggregation
@@ -1519,9 +1567,9 @@ mod tests {
let searcher = reader.searcher(); let searcher = reader.searcher();
let agg = json!({ let agg = json!({
"test": { "test": {
"filter": deserialized, "filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } } "aggs": { "count": { "value_count": { "field": "brand" } } }
} }
}); });


@@ -1,6 +1,6 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use columnar::{Column, ColumnBlockAccessor, ColumnType}; use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax; use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
}; };
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations; use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry; use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry, IntermediateHistogramBucketEntry,
}; };
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*; use crate::aggregation::*;
use crate::TantivyError; use crate::TantivyError;
@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>, pub accessor: Column<u64>,
/// The field type of the fast field. /// The field type of the fast field.
pub field_type: ColumnType, pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation. /// The name of the aggregation.
pub name: String, pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request. /// The histogram aggregation request.
pub req: HistogramAggregation, pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation. /// True if this is a date_histogram aggregation.
@@ -257,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry { pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64, pub key: f64,
pub doc_count: u64, pub doc_count: u64,
pub bucket_id: BucketId,
} }
impl SegmentHistogramBucketEntry { impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry( pub(crate) fn into_intermediate_bucket_entry(
self, self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>, sub_aggregation: &mut Option<HighCardCachedSubAggs>,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> { ) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default(); let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation { if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?; .get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
} }
Ok(IntermediateHistogramBucketEntry { Ok(IntermediateHistogramBucketEntry {
key: self.key, key: self.key,
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
} }
} }
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to /// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype. /// the correct datatype.
#[derive(Clone, Debug)] #[derive(Debug)]
pub struct SegmentHistogramCollector { pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data. /// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>, /// One Histogram bucket per parent bucket id.
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>, parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardCachedSubAggs>,
accessor_idx: usize, accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
} }
impl SegmentAggregationCollector for SegmentHistogramCollector { impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let name = agg_data let name = agg_data
.get_histogram_req_data(self.accessor_idx) .get_histogram_req_data(self.accessor_idx)
.name .name
.clone(); .clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?; // TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?; results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(()) Ok(())
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline] #[inline]
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let mut req = agg_data.take_histogram_req_data(self.accessor_idx); let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption(); let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds; let bounds = req.bounds;
let interval = req.req.interval; let interval = req.req.interval;
let offset = req.offset; let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64; let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
req.column_block_accessor.fetch_block(docs, &req.accessor); agg_data
for (doc, val) in req .column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor .column_block_accessor
.iter_docid_vals(docs, &req.accessor) .iter_docid_vals(docs, &req.accessor)
{ {
let val = f64_from_fastfield_u64(val, &req.field_type); let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val); let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) { if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| { let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset); let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 } SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
}); });
bucket.doc_count += 1; bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() { if let Some(sub_agg) = &mut self.sub_agg {
self.sub_aggregations sub_agg.push(bucket.bucket_id, doc);
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
} }
} }
} }
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?; .add_memory_consumed(mem_delta as u64)?;
} }
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(()) Ok(())
} }
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() { if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?; sub_aggregation.flush(agg_data)?;
} }
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(()) Ok(())
} }
} }
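For reference, the bucketing helpers used above (`get_bucket_pos_f64` / `get_bucket_key_from_pos`) are not part of this diff; the sketch below assumes the usual histogram semantics, where a value maps to position floor((val - offset) / interval) and the bucket key is that position scaled back into value space:

// Assumed semantics, not the actual tantivy helpers.
fn bucket_pos(val: f64, interval: f64, offset: f64) -> f64 {
    ((val - offset) / interval).floor()
}

fn bucket_key(pos: f64, interval: f64, offset: f64) -> f64 {
    pos * interval + offset
}

fn main() {
    let (interval, offset) = (10.0, 0.0);
    for val in [0.0, 9.9, 10.0, 37.5, -3.0] {
        let pos = bucket_pos(val, interval, offset);
        println!("val {val:>5} -> pos {pos:>3} -> key {}", bucket_key(pos, interval, offset));
    }
}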
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector { impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize { fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>(); let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption(); let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
let buckets_mem = self.buckets.memory_consumption(); self_mem + buckets_mem
self_mem + sub_aggs_mem + buckets_mem
} }
/// Converts the collector result into a intermediate bucket result. /// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result( fn add_intermediate_bucket_result(
self, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> { ) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len()); let mut buckets = Vec::with_capacity(histogram.buckets.len());
for (bucket_pos, bucket) in self.buckets { for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry( let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
buckets.push(bucket_res?); buckets.push(bucket_res?);
} }
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode, node: &AggRefNode,
) -> crate::Result<Self> { ) -> crate::Result<Self> {
let blueprint = if !node.children.is_empty() { let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?) Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else { } else {
None None
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
max: f64::MAX, max: f64::MAX,
}); });
req_data.offset = req_data.req.offset.unwrap_or(0.0); req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(CachedSubAggs::new);
req_data.sub_aggregation_blueprint = blueprint;
Ok(Self { Ok(Self {
buckets: Default::default(), parent_buckets: Default::default(),
sub_aggregations: Default::default(), sub_agg,
accessor_idx: node.idx_in_req_data, accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
}) })
} }
} }


@@ -1,18 +1,22 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::ops::Range; use std::ops::Range;
use columnar::{Column, ColumnBlockAccessor, ColumnType}; use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
}; };
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
}; };
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*; use crate::aggregation::*;
use crate::TantivyError; use crate::TantivyError;
@@ -23,12 +27,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>, pub accessor: Column<u64>,
/// The type of the fast field. /// The type of the fast field.
pub field_type: ColumnType, pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request. /// The range aggregation request.
pub req: RangeAggregation, pub req: RangeAggregation,
/// The name of the aggregation. /// The name of the aggregation.
pub name: String, pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
} }
impl RangeAggReqData { impl RangeAggReqData {
@@ -151,19 +155,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to /// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype. /// the correct datatype.
#[derive(Clone, Debug)] pub struct SegmentRangeCollector<C: SubAggCache> {
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data. /// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>, /// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType, column_type: ColumnType,
pub(crate) accessor_idx: usize, pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
} }
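A toy sketch of the id flattening described in the comment above, reproducing its INFO/ERROR/WARN example (the `BucketIdProvider` implementation is not shown in this diff; it is assumed here to be a plain monotonically increasing counter):

fn main() {
    let mut next_bucket_id = 0u32;
    let terms = ["INFO", "ERROR", "WARN"]; // 3 parent buckets produced by the term agg
    let num_ranges = 4; // 4 range buckets per parent bucket

    for term in terms {
        let ids: Vec<u32> = (0..num_ranges)
            .map(|_| {
                let id = next_bucket_id;
                next_bucket_id += 1;
                id
            })
            .collect();
        println!("{term}: {ids:?}"); // INFO: [0..=3], ERROR: [4..=7], WARN: [8..=11]
    }
}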
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
.field("column_type", &self.column_type)
.field("accessor_idx", &self.accessor_idx)
.field("has_sub_agg", &self.sub_agg.is_some())
.finish()
}
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry { pub(crate) struct SegmentRangeBucketEntry {
pub key: Key, pub key: Key,
pub doc_count: u64, pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>, // pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`. /// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>, pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not /// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -184,48 +216,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry { impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry( pub(crate) fn into_intermediate_bucket_entry(
self, self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> { ) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default(); let sub_aggregation = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
Ok(IntermediateRangeBucketEntry { Ok(IntermediateRangeBucketEntry {
key: self.key.into(), key: self.key.into(),
doc_count: self.doc_count, doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res, sub_aggregation_res: sub_aggregation,
from: self.from, from: self.from,
to: self.to, to: self.to,
}) })
} }
} }
impl SegmentAggregationCollector for SegmentRangeCollector { impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type; let field_type = self.column_type;
let name = agg_data let name = agg_data
.get_range_req_data(self.accessor_idx) .get_range_req_data(self.accessor_idx)
.name .name
.to_string(); .to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
.buckets
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter() .into_iter()
.map(move |range_bucket| { .map(|range_bucket| {
Ok(( let bucket_id = range_bucket.bucket.bucket_id;
range_to_string(&range_bucket.range, &field_type)?, let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
range_bucket if let Some(sub_aggregation) = &mut self.sub_agg {
.bucket sub_aggregation
.into_intermediate_bucket_entry(agg_data)?, .get_sub_agg_collector()
)) .add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
}) })
.collect::<crate::Result<_>>()?; .collect::<crate::Result<_>>()?;
@@ -242,73 +276,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline] #[inline]
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
// Take request data to avoid borrow conflicts during sub-aggregation let req = agg_data.take_range_req_data(self.accessor_idx);
let mut req = agg_data.take_range_req_data(self.accessor_idx);
req.column_block_accessor.fetch_block(docs, &req.accessor); agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in req let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor .column_block_accessor
.iter_docid_vals(docs, &req.accessor) .iter_docid_vals(docs, &req.accessor)
{ {
let bucket_pos = self.get_bucket_pos(val); let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut self.buckets[bucket_pos]; let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1; bucket.bucket.doc_count += 1;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?; sub_agg.push(bucket.bucket.bucket_id, doc);
} }
} }
agg_data.put_back_range_req_data(self.accessor_idx, req); agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(()) Ok(())
} }
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> { fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() { if let Some(sub_agg) = self.sub_agg.as_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() { sub_agg.flush(agg_data)?;
sub_agg.flush(agg_data)?;
}
} }
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Build a concrete `SegmentRangeCollector` with either a low- or high-cardinality
/// sub-aggregation cache, depending on the aggregation level and the number of ranges.
pub(crate) fn build_segment_range_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let accessor_idx = node.idx_in_req_data;
let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
let field_type = req_data.field_type;
// TODO: A better metric instead of is_top_level would be the number of buckets expected.
// E.g. if the range agg is not top level, but the parent is a bucket agg with fewer than 10 buckets,
// we are still in low cardinality territory.
let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
bucket_id_provider: BucketIdProvider::default(),
limits: agg_data.context.limits.clone(),
}))
}
} }
impl SegmentRangeCollector { impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn from_req_and_validate( pub(crate) fn create_new_buckets(
req_data: &mut AggregationsSegmentCtx, &mut self,
node: &AggRefNode, agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Self> { ) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let accessor_idx = node.idx_in_req_data; let field_type = self.column_type;
let (field_type, ranges) = { let req_data = agg_data.get_range_req_data(self.accessor_idx);
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
// The range input on the request is f64. // The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64. // We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved. // The mapping from the conversion is monotonic so ordering is preserved.
let sub_agg_prototype = if !node.children.is_empty() { let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
.iter() .iter()
.map(|range| { .map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range let key = range
.key .key
.clone() .clone()
@@ -317,20 +392,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX { let to = if range.range.end == u64::MAX {
None None
} else { } else {
Some(f64_from_fastfield_u64(range.range.end, &field_type)) Some(f64_from_fastfield_u64(range.range.end, field_type))
}; };
let from = if range.range.start == u64::MIN { let from = if range.range.start == u64::MIN {
None None
} else { } else {
Some(f64_from_fastfield_u64(range.range.start, &field_type)) Some(f64_from_fastfield_u64(range.range.start, field_type))
}; };
let sub_aggregation = sub_agg_prototype.clone(); // let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry { Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(), range: range.range.clone(),
bucket: SegmentRangeBucketEntry { bucket: SegmentRangeBucketEntry {
doc_count: 0, doc_count: 0,
sub_aggregation, bucket_id,
key, key,
from, from,
to, to,
@@ -339,27 +414,20 @@ impl SegmentRangeCollector {
}) })
.collect::<crate::Result<_>>()?; .collect::<crate::Result<_>>()?;
req_data.context.limits.add_memory_consumed( self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64, buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?; )?;
Ok(buckets)
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
} }
} }
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
let pos = buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(buckets[pos].range.contains(&val));
pos
}
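A minimal, self-contained illustration of this lookup (hypothetical helper, not tantivy code): bucket starts are sorted, the value is binary-searched against the starts, and an `Err(pos)` means the value belongs to the preceding bucket. Because `extend_validate_ranges` makes the first bucket start at `u64::MIN`, the `Err(0)` case cannot occur:

fn bucket_pos(val: u64, starts: &[u64]) -> usize {
    starts.binary_search(&val).unwrap_or_else(|pos| pos - 1)
}

fn main() {
    // Starts for the ranges [MIN..10), [10..20), [20..MAX]
    let starts = [u64::MIN, 10, 20];
    assert_eq!(bucket_pos(5, &starts), 0);
    assert_eq!(bucket_pos(15, &starts), 1); // Err(2) -> 2 - 1 = 1
    assert_eq!(bucket_pos(20, &starts), 2); // exact hit on a bucket start
    assert_eq!(bucket_pos(u64::MAX, &starts), 2);
    println!("all values mapped to the expected bucket");
}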
/// Converts the user provided f64 range value to fast field value space. /// Converts the user provided f64 range value to fast field value space.
/// ///
@@ -456,7 +524,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val); let val = i64::from_u64(val);
format_date(val) format_date(val)
} else { } else {
Ok(f64_from_fastfield_u64(val, field_type).to_string()) Ok(f64_from_fastfield_u64(val, *field_type).to_string())
} }
}; };
@@ -486,7 +554,7 @@ mod tests {
pub fn get_collector_from_ranges( pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>, ranges: Vec<RangeAggregationRange>,
field_type: ColumnType, field_type: ColumnType,
) -> SegmentRangeCollector { ) -> SegmentRangeCollector<HighCardSubAggCache> {
let req = RangeAggregation { let req = RangeAggregation {
field: "dummy".to_string(), field: "dummy".to_string(),
ranges, ranges,
@@ -506,30 +574,33 @@ mod tests {
let to = if range.range.end == u64::MAX { let to = if range.range.end == u64::MAX {
None None
} else { } else {
Some(f64_from_fastfield_u64(range.range.end, &field_type)) Some(f64_from_fastfield_u64(range.range.end, field_type))
}; };
let from = if range.range.start == u64::MIN { let from = if range.range.start == u64::MIN {
None None
} else { } else {
Some(f64_from_fastfield_u64(range.range.start, &field_type)) Some(f64_from_fastfield_u64(range.range.start, field_type))
}; };
SegmentRangeAndBucketEntry { SegmentRangeAndBucketEntry {
range: range.range.clone(), range: range.range.clone(),
bucket: SegmentRangeBucketEntry { bucket: SegmentRangeBucketEntry {
doc_count: 0, doc_count: 0,
sub_aggregation: None,
key, key,
from, from,
to, to,
bucket_id: 0,
}, },
} }
}) })
.collect(); .collect();
SegmentRangeCollector { SegmentRangeCollector {
buckets, parent_buckets: vec![buckets],
column_type: field_type, column_type: field_type,
accessor_idx: 0, accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
} }
} }
@@ -776,7 +847,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()]; let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64); let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets; let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -799,7 +870,7 @@ mod tests {
]; ];
let collector = get_collector_from_ranges(buckets, ColumnType::F64); let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets; let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN); assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64()); assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64()); assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -814,7 +885,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()]; let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64); let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets; let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10"); assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
} }
@@ -823,7 +894,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()]; let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64); let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets; let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0"); assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*"); assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
} }
@@ -832,7 +903,7 @@ mod tests {
fn range_binary_search_test_u64() { fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| { let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64); let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val); let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0); assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0); assert_eq!(search(9), 0);
@@ -878,7 +949,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()]; let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64); let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val); let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0); assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0); assert_eq!(search(9f64.to_u64()), 0);
@@ -890,63 +961,3 @@ mod tests {
// the max value // the max value
} }
} }
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large


@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx, build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
}; };
use crate::aggregation::bucket::term_agg::TermsAggregation; use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult, IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
}; };
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations. /// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence. /// This missing aggregation will check multiple columns for existence.
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
} }
} }
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
pub struct TermMissingAgg { struct MissingCount {
missing_count: u32, missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize, accessor_idx: usize,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>, sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
} }
impl TermMissingAgg { impl TermMissingAgg {
pub(crate) fn new( pub(crate) fn new(
req_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode, node: &AggRefNode,
) -> crate::Result<Self> { ) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty(); let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data; let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations { let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?; let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
Some(sub_aggregation) Some(sub_aggregation)
} else { } else {
None None
}; };
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self { Ok(Self {
accessor_idx, accessor_idx,
sub_agg, sub_agg,
..Default::default() missing_count_per_bucket: Vec::new(),
bucket_id_provider,
}) })
} }
} }
impl SegmentAggregationCollector for TermMissingAgg { impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req; let term_agg = &req_data.req;
let missing = term_agg let missing = term_agg
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default(); Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry { let mut missing_entry = IntermediateTermBucketEntry {
doc_count: self.missing_count, doc_count: missing_count.missing_count,
sub_aggregation: Default::default(), sub_aggregation: Default::default(),
}; };
if let Some(sub_agg) = self.sub_agg { if let Some(sub_agg) = &mut self.sub_agg {
let mut res = IntermediateAggregationResults::default(); let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?; sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
missing_entry.sub_aggregation = res; missing_entry.sub_aggregation = res;
} }
entries.insert(missing.into(), missing_entry); entries.insert(missing.into(), missing_entry);
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx); let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors for doc in docs {
.iter() let doc = *doc;
.any(|(acc, _)| acc.index.has_value(doc)); let has_value = req_data
if !has_value { .accessors
self.missing_count += 1; .iter()
if let Some(sub_agg) = self.sub_agg.as_mut() { .any(|(acc, _)| acc.index.has_value(doc));
sub_agg.collect(doc, agg_data)?; if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
} }
} }
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(()) Ok(())
} }
fn collect_block( fn prepare_max_bucket(
&mut self, &mut self,
docs: &[crate::DocId], max_bucket: BucketId,
agg_data: &mut AggregationsSegmentCtx, _agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
for doc in docs { while self.missing_count_per_bucket.len() <= max_bucket as usize {
self.collect(*doc, agg_data)?; let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
} }
Ok(()) Ok(())
} }


@@ -1,87 +0,0 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_data)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
self.num_staged_docs = 0;
self.collector.flush(agg_data)?;
Ok(())
}
}


@@ -0,0 +1,245 @@
use std::fmt::Debug;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this data structure does in general is group docs by bucket id.
#[derive(Debug)]
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()>;
}
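A self-contained toy of what the cache does conceptually (toy types, not the tantivy implementation): it buffers `(bucket_id, doc_id)` pairs and hands them to the sub-aggregation grouped by bucket, so the boxed, dynamically dispatched collector is invoked once per group rather than once per document:

use std::collections::BTreeMap;

fn main() {
    // Buffered (bucket_id, doc_id) pairs, in collection order.
    let hits: Vec<(u32, u32)> = vec![(0, 1), (2, 3), (0, 4), (1, 5), (2, 6)];

    // Group the buffered doc ids by bucket id.
    let mut grouped: BTreeMap<u32, Vec<u32>> = BTreeMap::new();
    for (bucket_id, doc_id) in hits {
        grouped.entry(bucket_id).or_default().push(doc_id);
    }

    // "Flush": one call per bucket with all of its docs.
    for (bucket_id, docs) in &grouped {
        println!("collect(bucket {bucket_id}, docs {docs:?})");
    }
}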
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
/// Check if we need to flush based on the number of documents cached.
/// If so, flushes the cache to the provided aggregation collector.
pub fn check_flush_local(
&mut self,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
Ok(())
}
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
self.sub_agg_collector.flush(agg_data)?;
Ok(())
}
}
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// Bucket ids are dense, so even if we fail to detect low cardinality but there are just
/// 16 bucket ids, each bucket id still ends up in its own partition.
///
/// We want to keep this cheap, because high cardinality aggregations can have a lot of
/// buckets, and there may be nothing to group.
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
partition.clear();
}
}
}
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
bucket_ids: Vec<BucketId>,
docs: Vec<DocId>,
}
impl PartitionEntry {
#[inline]
fn clear(&mut self) {
self.bucket_ids.clear();
self.docs.clear();
}
}
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
slot.bucket_ids.push(bucket_id);
slot.docs.push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
_force: bool,
) -> crate::Result<()> {
let mut max_bucket = 0u32;
for partition in self.partitions.iter() {
if let Some(&local_max) = partition.bucket_ids.iter().max() {
max_bucket = max_bucket.max(local_max);
}
}
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
for slot in self.partitions.iter() {
if !slot.bucket_ids.is_empty() {
// Reduce dynamic dispatch overhead by collecting a full partition in one call.
sub_agg.collect_multiple(&slot.bucket_ids, &slot.docs, agg_data)?;
}
}
self.clear();
Ok(())
}
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
v.clear();
}
}
}
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),
}
}
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {
self.per_bucket_docs.resize_with(idx + 1, Vec::new);
}
self.per_bucket_docs[idx].push(doc_id);
}
fn flush_local(
&mut self,
sub_agg: &mut Box<dyn SegmentAggregationCollector>,
agg_data: &mut AggregationsSegmentCtx,
force: bool,
) -> crate::Result<()> {
// Pre-aggregated: call collect per bucket.
let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
sub_agg.prepare_max_bucket(max_bucket, agg_data)?;
// The threshold above which we flush buckets individually.
// Note: We need to make sure that we don't lock ourselves into a situation where we hit
// the FLUSH_THRESHOLD, but never flush any buckets. (except the final flush)
let mut bucket_treshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
const _: () = {
// MAX_NUM_TERMS_FOR_VEC threshold is used for term aggregations
// Note: There may be other flexible values for other aggregations, but we can use the
// const value here as an upper bound. (better than nothing)
let bucket_treshold_limit = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
assert!(
bucket_treshold_limit > 0,
"Bucket threshold must be greater than 0"
);
};
if force {
bucket_treshold = 0;
}
for (bucket_id, docs) in self
.per_bucket_docs
.iter()
.enumerate()
.filter(|(_, docs)| docs.len() > bucket_treshold)
{
sub_agg.collect(bucket_id as BucketId, docs, agg_data)?;
}
self.clear();
Ok(())
}
}
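The per-bucket threshold above works out as follows (FLUSH_THRESHOLD = 2048 in this file): only buckets that have buffered more than FLUSH_THRESHOLD / (num_buckets * 2) docs are flushed early, and the final forced flush drains everything else:

fn main() {
    const FLUSH_THRESHOLD: usize = 2048;
    for num_buckets in [1usize, 4, 64, 512] {
        let threshold = FLUSH_THRESHOLD / (num_buckets.max(1) * 2);
        println!("{num_buckets:>4} buckets -> early-flush buckets holding > {threshold} docs");
    }
}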


@@ -1,9 +1,9 @@
use super::agg_req::Aggregations; use super::agg_req::Aggregations;
use super::agg_result::AggregationResults; use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector; use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults; use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams; use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{ use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx, build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
}; };
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment. /// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector { pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx, aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: BufAggregationCollector, agg_collector: LowCardCachedSubAggs,
error: Option<TantivyError>, error: Option<TantivyError>,
} }
@@ -151,8 +151,11 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> { ) -> crate::Result<Self> {
let mut agg_data = let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?; build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result = let mut result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?); LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
Ok(AggregationSegmentCollector { Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data, aggs_with_accessor: agg_data,
@@ -170,26 +173,31 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() { if self.error.is_some() {
return; return;
} }
if let Err(err) = self self.agg_collector.push(0, doc);
match self
.agg_collector .agg_collector
.collect(doc, &mut self.aggs_with_accessor) .check_flush_local(&mut self.aggs_with_accessor)
{ {
self.error = Some(err); Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
} }
} }
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) { fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() { if self.error.is_some() {
return; return;
} }
if let Err(err) = self
.agg_collector match self.agg_collector.get_sub_agg_collector().collect(
.collect_block(docs, &mut self.aggs_with_accessor) 0,
{ docs,
self.error = Some(err); &mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
} }
} }
@@ -200,10 +208,13 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?; self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default(); let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result( self.agg_collector
&self.aggs_with_accessor, .get_sub_agg_collector()
&mut sub_aggregation_res, .add_intermediate_aggregation_result(
)?; &self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Ok(sub_aggregation_res) Ok(sub_aggregation_res)
} }

View File

@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket. /// The number of documents in the bucket.
pub doc_count: u64, pub doc_count: u64,
/// The sub_aggregation in this bucket. /// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults, pub sub_aggregation_res: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`. /// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>, pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. /// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(), key: self.key.into(),
doc_count: self.doc_count, doc_count: self.doc_count,
sub_aggregation: self sub_aggregation: self
.sub_aggregation .sub_aggregation_res
.into_final_result_internal(req, limits)?, .into_final_result_internal(req, limits)?,
to: self.to, to: self.to,
from: self.from, from: self.from,
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry { impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> { fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count; self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?; self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
Ok(()) Ok(())
} }
} }
@@ -887,7 +888,7 @@ mod tests {
IntermediateRangeBucketEntry { IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()), key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count, doc_count: *doc_count,
sub_aggregation: Default::default(), sub_aggregation_res: Default::default(),
from: None, from: None,
to: None, to: None,
}, },
@@ -920,7 +921,7 @@ mod tests {
doc_count: *doc_count, doc_count: *doc_count,
from: None, from: None,
to: None, to: None,
sub_aggregation: get_sub_test_tree(&[( sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation_key.to_string(), sub_aggregation_key.to_string(),
*sub_aggregation_count, *sub_aggregation_count,
)]), )]),

View File

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
impl IntermediateAverage { impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`]. /// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { Self { stats }
stats: collector.stats,
}
} }
/// Merges the other intermediate result into self. /// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) { pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher}; use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor; use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn}; use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64; use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>, pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type. /// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>, pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation. /// The name of the aggregation.
pub name: String, pub name: String,
/// The aggregation request. /// The aggregation request.
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
} }
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector { pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector, buckets: Vec<SegmentCardinalityCollectorBucket>,
entries: FxHashSet<u64>,
accessor_idx: usize, accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
} }
impl SegmentCardinalityCollector { #[derive(Clone, Debug, PartialEq, Default)]
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self { pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
Self { Self {
cardinality: CardinalityCollector::new(column_type as u8), cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(), entries: FxHashSet::default(),
accessor_idx,
} }
} }
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result( fn into_intermediate_metric_result(
mut self, mut self,
agg_data: &AggregationsSegmentCtx, req_data: &CardinalityAggReqData,
) -> crate::Result<IntermediateMetricResult> { ) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str { if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty(); let fallback_dict = Dictionary::empty();
let dict = req_data let dict = req_data
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
term_ids.push(term_ord as u32); term_ids.push(term_ord as u32);
} }
} }
term_ids.sort_unstable(); term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| { dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term); self.cardinality.sketch.insert_any(&term);
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
} }
} }
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector { impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx); let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string(); let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = self.into_intermediate_metric_result(agg_data)?; let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
results.push( results.push(
name, name,
IntermediateAggregationResult::Metric(intermediate_result), IntermediateAggregationResult::Metric(intermediate_result),
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx); self.fetch_block_with_field(docs, agg_data);
self.fetch_block_with_field(docs, req_data); let bucket = &mut self.buckets[parent_bucket_id as usize];
let col_block_accessor = &req_data.column_block_accessor; let col_block_accessor = &agg_data.column_block_accessor;
if req_data.column_type == ColumnType::Str { if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() { for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord); bucket.entries.insert(term_ord);
} }
} else if req_data.column_type == ColumnType::IpAddr { } else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data let compact_space_accessor = self
.accessor .accessor
.values .values
.clone() .clone()
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?; })?;
for val in col_block_accessor.iter_vals() { for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32); let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val); bucket.cardinality.sketch.insert_any(&val);
} }
} else { } else {
for val in col_block_accessor.iter_vals() { for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val); bucket.cardinality.sketch.insert_any(&val);
} }
} }
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateCount {
impl IntermediateCount { impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`]. /// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { Self { stats }
stats: collector.stats,
}
} }
/// Merges the other intermediate result into self. /// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) { pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
}; };
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*; use crate::aggregation::*;
use crate::{DocId, TantivyError}; use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of extended statistics /// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted /// on numeric values that are extracted
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
} }
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector { pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>, missing: Option<u64>,
field_type: ColumnType, field_type: ColumnType,
pub(crate) extended_stats: IntermediateExtendedStats, accessor: columnar::Column<u64>,
pub(crate) accessor_idx: usize, buckets: Vec<IntermediateExtendedStats>,
val_cache: Vec<u64>, sigma: Option<f64>,
} }
impl SegmentExtendedStatsCollector { impl SegmentExtendedStatsCollector {
pub fn from_req( pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
field_type: ColumnType, let missing = req
sigma: Option<f64>, .missing
accessor_idx: usize, .and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
Self { Self {
field_type, name: req.name.clone(),
extended_stats: IntermediateExtendedStats::with_sigma(sigma), field_type: req.field_type,
accessor_idx, accessor: req.accessor.clone(),
missing, missing,
val_cache: Default::default(), buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
} sigma,
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
} }
} }
} }
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector { impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline] #[inline]
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
results.push( results.push(
name, name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats( IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
self.extended_stats, extended_stats,
)), )),
)?; )?;
@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline] #[inline]
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx); let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
if let Some(missing) = self.missing {
let mut has_val = false; agg_data
for val in req_data.accessor.values_for_doc(doc) { .column_block_accessor
let val1 = f64_from_fastfield_u64(val, &self.field_type); .fetch_block_with_missing(docs, &self.accessor, self.missing);
self.extended_stats.collect(val1); for val in agg_data.column_block_accessor.iter_vals() {
has_val = true; let val1 = f64_from_fastfield_u64(val, self.field_type);
} extended_stats.collect(val1);
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
} }
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(()) Ok(())
} }
#[inline] fn prepare_max_bucket(
fn collect_block(
&mut self, &mut self,
docs: &[crate::DocId], max_bucket: BucketId,
agg_data: &mut AggregationsSegmentCtx, _agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); if self.buckets.len() <= max_bucket as usize {
self.collect_block_with_field(docs, req_data); self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
Ok(()) Ok(())
} }
} }

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMax {
impl IntermediateMax { impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`]. /// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { Self { stats }
stats: collector.stats,
}
} }
/// Merges the other intermediate result into self. /// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) { pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMin {
impl IntermediateMin { impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`]. /// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { Self { stats }
stats: collector.stats,
}
} }
/// Merges the other intermediate result into self. /// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) { pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*; pub use average::*;
pub use cardinality::*; pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType}; use columnar::{Column, ColumnType};
pub use count::*; pub use count::*;
pub use extended_stats::*; pub use extended_stats::*;
pub use max::*; pub use max::*;
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
pub field_type: ColumnType, pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type. /// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>, pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values. /// The column accessor to access the fast field values.
pub accessor: Column<u64>, pub accessor: Column<u64>,
/// Used when converting to intermediate result /// Used when converting to intermediate result

View File

@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
}; };
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*; use crate::aggregation::*;
use crate::{DocId, TantivyError}; use crate::TantivyError;
/// # Percentiles /// # Percentiles
/// ///
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
} }
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector { pub(crate) struct SegmentPercentilesCollector {
pub(crate) percentiles: PercentilesCollector, pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) accessor_idx: usize, pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@@ -229,33 +234,18 @@ impl PercentilesCollector {
} }
impl SegmentPercentilesCollector { impl SegmentPercentilesCollector {
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> { pub fn from_req_and_validate(
Ok(Self { field_type: ColumnType,
percentiles: PercentilesCollector::new(), missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
accessor_idx, accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
} }
} }
} }
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector { impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline] #[inline]
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone(); let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles); self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
results.push( results.push(
name, name,
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline] #[inline]
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx); let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
if let Some(missing) = req_data.missing_u64 { for val in agg_data.column_block_accessor.iter_vals() {
let mut has_val = false; let val1 = f64_from_fastfield_u64(val, self.field_type);
for val in req_data.accessor.values_for_doc(doc) { percentiles.collect(val1);
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
} }
Ok(()) Ok(())
} }
#[inline] fn prepare_max_bucket(
fn collect_block(
&mut self, &mut self,
docs: &[crate::DocId], max_bucket: BucketId,
agg_data: &mut AggregationsSegmentCtx, _agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); while self.buckets.len() <= max_bucket as usize {
self.collect_block_with_field(docs, req_data); self.buckets.push(PercentilesCollector::new());
}
Ok(()) Ok(())
} }
} }

View File

@@ -1,5 +1,6 @@
use std::fmt::Debug; use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::*; use super::*;
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
}; };
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*; use crate::aggregation::*;
use crate::{DocId, TantivyError}; use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that /// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents. /// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate /// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results. /// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats { pub struct IntermediateStats {
/// The number of extracted values. /// The number of extracted values.
pub(crate) count: u64, pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
Percentiles, Percentiles,
} }
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
}
}
#[repr(C)]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector { pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) stats: IntermediateStats, pub(crate) missing_u64: Option<u64>,
pub(crate) accessor_idx: usize, pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
} }
impl SegmentStatsCollector { impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
pub fn from_req(accessor_idx: usize) -> Self { for SegmentStatsCollector<COLUMN_TYPE_ID>
Self { {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline] #[inline]
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req = agg_data.get_metric_req_data(self.accessor_idx); let name = self.name.clone();
let name = req.name.clone();
let intermediate_metric_result = match req.collecting_for { self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
StatsType::Average => { StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self)) IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
} }
StatsType::Count => { StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self)) IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
} }
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)), StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)), StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats), StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)), StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => { _ => {
return Err(TantivyError::InvalidArgument(format!( return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}", "Unsupported stats type for stats aggregation: {:?}",
req.collecting_for self.collecting_for
))) )))
} }
}; };
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline] #[inline]
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx); // TODO: remove once we fetch all values for all bucket ids in one go
self.collect_block_with_field(docs, req_data); if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateSum {
impl IntermediateSum { impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`]. /// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self { pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { Self { stats }
stats: collector.stats,
}
} }
/// Merges the other intermediate result into self. /// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) { pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -15,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult, IntermediateAggregationResult, IntermediateMetricResult,
}; };
use crate::aggregation::segment_agg_result::SegmentAggregationCollector; use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError; use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator; use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer; use crate::collector::TopNComputer;
use crate::schema::OwnedValue; use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal}; use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the /// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment. /// top_hits aggregation on a segment.
@@ -472,7 +471,10 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector /// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self { pub fn new(req: &TopHitsAggregationReq) -> Self {
Self { Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
req: req.clone(), req: req.clone(),
} }
} }
@@ -518,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector { pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal, segment_ordinal: SegmentOrdinal,
accessor_idx: usize, accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>, buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
} }
impl TopHitsSegmentCollector { impl TopHitsSegmentCollector {
@@ -527,19 +530,29 @@ impl TopHitsSegmentCollector {
accessor_idx: usize, accessor_idx: usize,
segment_ordinal: SegmentOrdinal, segment_ordinal: SegmentOrdinal,
) -> Self { ) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self { Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), num_hits,
segment_ordinal, segment_ordinal,
accessor_idx, accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
} }
} }
fn into_top_hits_collector( fn get_top_hits_computer(
self, &mut self,
parent_bucket_id: BucketId,
value_accessors: &HashMap<String, Vec<DynamicColumn>>, value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq, req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer { ) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req); let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec(); let top_results = top_n.into_vec();
for res in top_results { for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
@@ -554,54 +567,24 @@ impl TopHitsSegmentCollector {
top_hits_computer top_hits_computer
} }
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
} }
impl SegmentAggregationCollector for TopHitsSegmentCollector { impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors; let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits( let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
self.into_top_hits_collector(value_accessors, &req_data.req), parent_bucket_id,
); value_accessors,
&req_data.req,
));
results.push( results.push(
req_data.name.to_string(), req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result), IntermediateAggregationResult::Metric(intermediate_result),
@@ -611,26 +594,56 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead /// TODO: Consider a caching layer to reduce the call overhead
fn collect( fn collect(
&mut self, &mut self,
doc_id: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx); let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor. let req = &req_data.req;
for doc in docs { let accessors = &req_data.accessors;
self.collect_with(*doc, &req_data.req, &req_data.accessors)?; for &doc_id in docs {
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
} }
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
@@ -746,7 +759,7 @@ mod tests {
], ],
"from": 0, "from": 0,
} }
} }
})) }))
.unwrap(); .unwrap();
@@ -875,7 +888,7 @@ mod tests {
"mixed.*", "mixed.*",
], ],
} }
} }
}))?; }))?;
let collector = AggregationCollector::from_aggs(d, Default::default()); let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req; pub mod agg_req;
pub mod agg_result; pub mod agg_result;
pub mod bucket; pub mod bucket;
mod buf_collector; pub(crate) mod cached_sub_aggs;
mod collector; mod collector;
mod date; mod date;
mod error; mod error;
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager; use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that holds per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned an incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This makes it possible to have a single AggregationCollector instance per aggregation
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
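
A minimal, self-contained sketch of the pattern described above (hypothetical SubAggCollector/CountPerBucket names, not tantivy's API): a single sub-aggregation collector keeps per-bucket state indexed by BucketId and is driven with (bucket_id, docs) slices by its parent aggregation.

type BucketId = u32;
type DocId = u32;

trait SubAggCollector {
    /// Grow per-bucket state so that `max_bucket` is a valid index.
    fn prepare_max_bucket(&mut self, max_bucket: BucketId);
    /// Collect a block of docs that all belong to `parent_bucket_id`.
    fn collect(&mut self, parent_bucket_id: BucketId, docs: &[DocId]);
}

/// Hypothetical sub-aggregation: counts docs per parent bucket.
#[derive(Default)]
struct CountPerBucket {
    counts: Vec<u64>,
}

impl SubAggCollector for CountPerBucket {
    fn prepare_max_bucket(&mut self, max_bucket: BucketId) {
        if self.counts.len() <= max_bucket as usize {
            self.counts.resize(max_bucket as usize + 1, 0);
        }
    }
    fn collect(&mut self, parent_bucket_id: BucketId, docs: &[DocId]) {
        self.counts[parent_bucket_id as usize] += docs.len() as u64;
    }
}

fn main() {
    // The parent bucket aggregation buffers docs per bucket and flushes them
    // as contiguous runs, so a single sub-aggregation instance serves all buckets.
    let per_bucket_docs: Vec<(BucketId, Vec<DocId>)> =
        vec![(0, vec![1, 2, 3]), (1, vec![4]), (0, vec![7, 9])];
    let mut sub_agg = CountPerBucket::default();
    sub_agg.prepare_max_bucket(1);
    for (bucket_id, docs) in &per_bucket_docs {
        sub_agg.collect(*bucket_id, docs);
    }
    assert_eq!(sub_agg.counts, vec![5u64, 1]);
}
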
/// Context parameters for aggregation execution /// Context parameters for aggregation execution
/// ///
/// This struct holds shared resources needed during aggregation execution: /// This struct holds shared resources needed during aggregation execution:
@@ -335,19 +348,37 @@ impl Display for Key {
} }
} }
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics. /// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
/// ///
/// # Panics /// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported. /// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 { pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
match field_type { match field_type {
ColumnType::U64 => val as f64, ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64, ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => f64::from_u64(val), ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => val as f64, ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
_ => { ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
panic!("unexpected type {field_type:?}. This should not happen") _ => panic!("unexpected type {field_type:?}. This should not happen"),
}
} }
} }
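
The conversions above rely on an order-preserving mapping between the fast-field u64 representation and the original value. A standalone sketch of that idea for f64 (illustrative only; the real mapping lives in the common/columnar crates):

fn f64_to_u64_monotone(val: f64) -> u64 {
    let bits = val.to_bits();
    if val.is_sign_negative() {
        // Negative floats: flipping all bits makes more-negative values map lower.
        !bits
    } else {
        // Non-negative floats: set the sign bit so they sort above all negatives.
        bits ^ (1u64 << 63)
    }
}

fn u64_to_f64_monotone(val: u64) -> f64 {
    if val & (1u64 << 63) != 0 {
        f64::from_bits(val ^ (1u64 << 63))
    } else {
        f64::from_bits(!val)
    }
}

fn main() {
    let xs = [-2.5_f64, -0.0, 1.0, 3.25];
    let mapped: Vec<u64> = xs.iter().copied().map(f64_to_u64_monotone).collect();
    // Order is preserved, so metrics can compare the u64 fast-field values directly.
    assert!(mapped.windows(2).all(|w| w[0] <= w[1]));
    // And the mapping round-trips back to the original f64.
    assert!(xs.iter().zip(&mapped).all(|(x, &m)| *x == u64_to_f64_monotone(m)));
}
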

View File

@@ -8,25 +8,67 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard; pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults; use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx; use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results. /// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug { pub trait SegmentAggregationCollector: Debug {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>; ) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>; ) -> crate::Result<()>;
fn collect_block( /// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
&mut self, &mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
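
To illustrate the run-grouping performed by the default implementation above: consecutive equal bucket ids are batched into a single `collect` call per run. A standalone sketch with hypothetical input, mirroring the loop:

fn group_runs(bucket_ids: &[u32], docs: &[u32]) -> Vec<(u32, Vec<u32>)> {
    assert_eq!(bucket_ids.len(), docs.len());
    let mut runs = Vec::new();
    let mut start = 0;
    while start < bucket_ids.len() {
        let bucket_id = bucket_ids[start];
        let mut end = start + 1;
        while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
            end += 1;
        }
        // Each run becomes one `collect(bucket_id, &docs[start..end], ..)` call.
        runs.push((bucket_id, docs[start..end].to_vec()));
        start = end;
    }
    runs
}

fn main() {
    // Docs 10..=15 fall into buckets 0,0,1,1,1,0 respectively.
    let runs = group_runs(&[0, 0, 1, 1, 1, 0], &[10, 11, 12, 13, 14, 15]);
    assert_eq!(
        runs,
        vec![(0u32, vec![10u32, 11]), (1, vec![12, 13, 14]), (0, vec![15])]
    );
}
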
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can perform allocations ahead of time, before collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>; ) -> crate::Result<()>;
/// Finalize method. Some Aggregators collect blocks of docs before calling `collect_block`. /// Finalize method. Some Aggregators collect blocks of docs before calling `collect_block`.
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
} }
} }
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector> #[derive(Default)]
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which /// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one /// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead. /// and can provide specialized versions instead, that remove some of its overhead.
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector { impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
self: Box<Self>, &mut self,
agg_data: &AggregationsSegmentCtx, agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults, results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> { ) -> crate::Result<()> {
for agg in self.aggs { for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?; agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
} }
Ok(()) Ok(())
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect( fn collect(
&mut self, &mut self,
doc: crate::DocId, parent_bucket_id: BucketId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
docs: &[crate::DocId], docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx, agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> { ) -> crate::Result<()> {
for collector in &mut self.aggs { for collector in &mut self.aggs {
collector.collect_block(docs, agg_data)?; collector.collect(parent_bucket_id, docs, agg_data)?;
} }
Ok(()) Ok(())
} }
@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
} }
Ok(()) Ok(())
} }
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
} }

View File

@@ -1,10 +1,12 @@
mod order; mod order;
mod sort_by_erased_type;
mod sort_by_score; mod sort_by_score;
mod sort_by_static_fast_value; mod sort_by_static_fast_value;
mod sort_by_string; mod sort_by_string;
mod sort_key_computer; mod sort_key_computer;
pub use order::*; pub use order::*;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore; pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue; pub use sort_by_static_fast_value::SortByStaticFastValue;
pub use sort_by_string::SortByString; pub use sort_by_string::SortByString;
@@ -34,11 +36,13 @@ pub(crate) mod tests {
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Range; use std::ops::Range;
use crate::collector::sort_key::{SortBySimilarityScore, SortByStaticFastValue, SortByString}; use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs}; use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy; use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, QueryParser}; use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT}; use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::{DocAddress, Document, Index, Order, Score, Searcher}; use crate::{DocAddress, Document, Index, Order, Score, Searcher};
fn make_index() -> crate::Result<Index> { fn make_index() -> crate::Result<Index> {
@@ -313,11 +317,9 @@ pub(crate) mod tests {
(SortBySimilarityScore, score_order), (SortBySimilarityScore, score_order),
(SortByString::for_field("city"), city_order), (SortByString::for_field("city"), city_order),
)); ));
Ok(searcher let results: Vec<((Score, Option<String>), DocAddress)> =
.search(&AllQuery, &top_collector)? searcher.search(&AllQuery, &top_collector)?;
.into_iter() Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
.map(|(f, doc)| (f, ids[&doc]))
.collect())
} }
assert_eq!( assert_eq!(
@@ -342,6 +344,51 @@ pub(crate) mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_order_by_score_then_owned_value() -> crate::Result<()> {
let index = make_index()?;
type SortKey = (Score, OwnedValue);
fn query(
index: &Index,
score_order: Order,
city_order: Order,
) -> crate::Result<Vec<(SortKey, u64)>> {
let searcher = index.reader()?.searcher();
let ids = id_mapping(&searcher);
let top_collector = TopDocs::with_limit(4).order_by::<(Score, OwnedValue)>((
(SortBySimilarityScore, score_order),
(SortByErasedType::for_field("city"), city_order),
));
let results: Vec<((Score, OwnedValue), DocAddress)> =
searcher.search(&AllQuery, &top_collector)?;
Ok(results.into_iter().map(|(f, doc)| (f, ids[&doc])).collect())
}
assert_eq!(
&query(&index, Order::Asc, Order::Asc)?,
&[
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Null), 3),
]
);
assert_eq!(
&query(&index, Order::Asc, Order::Desc)?,
&[
((1.0, OwnedValue::Str("tokyo".to_owned())), 2),
((1.0, OwnedValue::Str("greenville".to_owned())), 1),
((1.0, OwnedValue::Str("austin".to_owned())), 0),
((1.0, OwnedValue::Null), 3),
]
);
Ok(())
}
use proptest::prelude::*; use proptest::prelude::*;
proptest! { proptest! {

View File

@@ -1,11 +1,70 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use columnar::MonotonicallyMappableToU64;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer}; use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::Schema; use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score}; use crate::{DocId, Order, Score};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
(OwnedValue::Null, OwnedValue::Null) => Ordering::Equal,
(OwnedValue::Null, _) => {
if NULLS_FIRST {
Ordering::Less
} else {
Ordering::Greater
}
}
(_, OwnedValue::Null) => {
if NULLS_FIRST {
Ordering::Greater
} else {
Ordering::Less
}
}
(OwnedValue::Str(a), OwnedValue::Str(b)) => a.cmp(b),
(OwnedValue::PreTokStr(a), OwnedValue::PreTokStr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::U64(b)) => a.cmp(b),
(OwnedValue::I64(a), OwnedValue::I64(b)) => a.cmp(b),
(OwnedValue::F64(a), OwnedValue::F64(b)) => a.to_u64().cmp(&b.to_u64()),
(OwnedValue::Bool(a), OwnedValue::Bool(b)) => a.cmp(b),
(OwnedValue::Date(a), OwnedValue::Date(b)) => a.cmp(b),
(OwnedValue::Facet(a), OwnedValue::Facet(b)) => a.cmp(b),
(OwnedValue::Bytes(a), OwnedValue::Bytes(b)) => a.cmp(b),
(OwnedValue::IpAddr(a), OwnedValue::IpAddr(b)) => a.cmp(b),
(OwnedValue::U64(a), OwnedValue::I64(b)) => {
if *b < 0 {
Ordering::Greater
} else {
a.cmp(&(*b as u64))
}
}
(OwnedValue::I64(a), OwnedValue::U64(b)) => {
if *a < 0 {
Ordering::Less
} else {
(*a as u64).cmp(b)
}
}
(OwnedValue::U64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::U64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(OwnedValue::I64(a), OwnedValue::F64(b)) => (*a as f64).to_u64().cmp(&b.to_u64()),
(OwnedValue::F64(a), OwnedValue::I64(b)) => a.to_u64().cmp(&(*b as f64).to_u64()),
(a, b) => {
let ord = a.discriminant_value().cmp(&b.discriminant_value());
// If the discriminant is equal, it's because a new type was added, but hasn't been
// included in this `match` statement.
assert!(
ord != Ordering::Equal,
"Unimplemented comparison for type of {a:?}, {b:?}"
);
ord
}
}
}
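The F64 arms above compare floats through their monotonic u64 mapping (`MonotonicallyMappableToU64`) instead of `partial_cmp`, which sidesteps the lack of a total order on f64. A rough standalone sketch of that kind of order-preserving mapping (illustrative only, not tantivy's actual implementation):

fn f64_to_sortable_u64(val: f64) -> u64 {
    let bits = val.to_bits();
    if val.is_sign_negative() {
        // Negative floats: flip all bits so that larger magnitudes sort lower.
        !bits
    } else {
        // Positive floats (and +0.0): set the sign bit so they sort above all negatives.
        bits | (1u64 << 63)
    }
}

fn main() {
    let mut vals = vec![3.5f64, -1.0, 0.0, -7.25, 2.0];
    vals.sort_by_key(|v| f64_to_sortable_u64(*v));
    assert_eq!(vals, vec![-7.25, -1.0, 0.0, 2.0, 3.5]);
}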
/// Comparator trait defining the order in which documents should be ordered. /// Comparator trait defining the order in which documents should be ordered.
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default { pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values. /// Return the order between two values.
@@ -25,7 +84,18 @@ pub struct NaturalComparator;
impl<T: PartialOrd> Comparator<T> for NaturalComparator { impl<T: PartialOrd> Comparator<T> for NaturalComparator {
#[inline(always)] #[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering { fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
lhs.partial_cmp(rhs).unwrap() lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal)
}
}
/// A (partial) implementation of comparison for OwnedValue.
///
/// Intended for use within columns of homogeneous types: mismatched numeric types are compared by
/// value, other mismatched types are ordered by their type discriminant, and comparing two values
/// of an unsupported type will panic. Null is handled in all cases.
impl Comparator<OwnedValue> for NaturalComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ true>(lhs, rhs)
} }
} }
@@ -121,6 +191,13 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
} }
} }
impl Comparator<OwnedValue> for ReverseNoneIsLowerComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(rhs, lhs)
}
}
/// Compare values naturally, but treating `None` as higher than `Some`. /// Compare values naturally, but treating `None` as higher than `Some`.
/// ///
/// When used with `TopDocs`, which reverses the order, this results in a /// When used with `TopDocs`, which reverses the order, this results in a
@@ -185,6 +262,13 @@ impl Comparator<String> for NaturalNoneIsHigherComparator {
} }
} }
impl Comparator<OwnedValue> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
compare_owned_value::</* NULLS_FIRST= */ false>(lhs, rhs)
}
}
/// An enum representing the different sort orders. /// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)] #[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum { pub enum ComparatorEnum {
@@ -404,11 +488,12 @@ impl<TSegmentSortKeyComputer, TSegmentSortKey, TComparator> SegmentSortKeyComput
for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator> for SegmentSortKeyComputerWithComparator<TSegmentSortKeyComputer, TComparator>
where where
TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>, TSegmentSortKeyComputer: SegmentSortKeyComputer<SegmentSortKey = TSegmentSortKey>,
TSegmentSortKey: PartialOrd + Clone + 'static + Sync + Send, TSegmentSortKey: Clone + 'static + Sync + Send,
TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send, TComparator: Comparator<TSegmentSortKey> + 'static + Sync + Send,
{ {
type SortKey = TSegmentSortKeyComputer::SortKey; type SortKey = TSegmentSortKeyComputer::SortKey;
type SegmentSortKey = TSegmentSortKey; type SegmentSortKey = TSegmentSortKey;
type SegmentComparator = TComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.segment_sort_key_computer.segment_sort_key(doc, score) self.segment_sort_key_computer.segment_sort_key(doc, score)
@@ -432,6 +517,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::schema::OwnedValue;
#[test] #[test]
fn test_natural_none_is_higher() { fn test_natural_none_is_higher() {
@@ -455,4 +541,27 @@ mod tests {
// compare(None, None) should be Equal. // compare(None, None) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal); assert_eq!(comp.compare(&null, &null), Ordering::Equal);
} }
#[test]
fn test_mixed_ownedvalue_compare() {
let u = OwnedValue::U64(10);
let i = OwnedValue::I64(10);
let f = OwnedValue::F64(10.0);
let nc = NaturalComparator;
assert_eq!(nc.compare(&u, &i), Ordering::Equal);
assert_eq!(nc.compare(&u, &f), Ordering::Equal);
assert_eq!(nc.compare(&i, &f), Ordering::Equal);
let u2 = OwnedValue::U64(11);
assert_eq!(nc.compare(&u2, &f), Ordering::Greater);
let s = OwnedValue::Str("a".to_string());
// Str < U64
assert_eq!(nc.compare(&s, &u), Ordering::Less);
// Str < I64
assert_eq!(nc.compare(&s, &i), Ordering::Less);
// Str < F64
assert_eq!(nc.compare(&s, &f), Ordering::Less);
}
} }

View File

@@ -0,0 +1,361 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
/// Using the OwnedValue representation allows for type erasure, and can be useful when sort orders
/// are not known until runtime. But it comes with a performance cost: wherever possible, prefer to
/// use a SortKeyComputer implementation with a type known at compile time.
#[derive(Debug, Clone)]
pub enum SortByErasedType {
/// Sort by a fast field
Field(String),
/// Sort by score
Score,
}
impl SortByErasedType {
/// Creates a new sort key computer which will sort by the given fast field column, with type
/// erasure.
pub fn for_field(column_name: impl ToString) -> Self {
Self::Field(column_name.to_string())
}
/// Creates a new sort key computer which will sort by score, with type erasure.
pub fn for_score() -> Self {
Self::Score
}
}
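A usage sketch for the API above, mirroring the tests further down in this file; the re-export paths for `SortByErasedType` and `ComparatorEnum` are assumed, and `field_name` stands in for a field that is only known at runtime:

use tantivy::collector::{ComparatorEnum, SortByErasedType, TopDocs};
use tantivy::query::AllQuery;
use tantivy::schema::OwnedValue;
use tantivy::{DocAddress, Searcher};

// Collect the top 10 documents ordered by a fast field whose name (and therefore
// column type) is only known at runtime, e.g. taken from a request parameter.
fn top_by_runtime_field(
    searcher: &Searcher,
    field_name: &str,
) -> tantivy::Result<Vec<(OwnedValue, DocAddress)>> {
    let collector = TopDocs::with_limit(10)
        .order_by((SortByErasedType::for_field(field_name), ComparatorEnum::Natural));
    searcher.search(&AllQuery, &collector)
}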
trait ErasedSegmentSortKeyComputer: Send + Sync {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64>;
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue;
}
struct ErasedSegmentSortKeyComputerWrapper<C, F> {
inner: C,
converter: F,
}
impl<C, F> ErasedSegmentSortKeyComputer for ErasedSegmentSortKeyComputerWrapper<C, F>
where
C: SegmentSortKeyComputer<SegmentSortKey = Option<u64>> + Send + Sync,
F: Fn(C::SortKey) -> OwnedValue + Send + Sync + 'static,
{
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let val = self.inner.convert_segment_sort_key(sort_key);
(self.converter)(val)
}
}
struct ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
}
impl ErasedSegmentSortKeyComputer for ScoreSegmentSortKeyComputer {
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
let score_value: f64 = self.segment_computer.segment_sort_key(doc, score).into();
Some(score_value.to_u64())
}
fn convert_segment_sort_key(&self, sort_key: Option<u64>) -> OwnedValue {
let score_value: u64 = sort_key.expect("This implementation always produces a score.");
OwnedValue::F64(f64::from_u64(score_value))
}
}
impl SortKeyComputer for SortByErasedType {
type SortKey = OwnedValue;
type Child = ErasedColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn requires_scoring(&self) -> bool {
matches!(self, Self::Score)
}
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {
let fast_fields = segment_reader.fast_fields();
// TODO: We currently double-open the column to avoid relying on the implementation
// details of `SortByString` or `SortByStaticFastValue`. Once
// https://github.com/quickwit-oss/tantivy/issues/2776 is resolved, we should
// consider directly constructing the appropriate `SegmentSortKeyComputer` type for
// the column that we open here.
let (_column, column_type) =
fast_fields.u64_lenient(column_name)?.ok_or_else(|| {
FastFieldNotAvailableError {
field_name: column_name.to_owned(),
}
})?;
match column_type {
ColumnType::Str => {
let computer = SortByString::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<String>| {
val.map(OwnedValue::Str).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<u64>| {
val.map(OwnedValue::U64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::I64 => {
let computer = SortByStaticFastValue::<i64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<i64>| {
val.map(OwnedValue::I64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::F64 => {
let computer = SortByStaticFastValue::<f64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<f64>| {
val.map(OwnedValue::F64).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::Bool => {
let computer = SortByStaticFastValue::<bool>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<bool>| {
val.map(OwnedValue::Bool).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::DateTime => {
let computer = SortByStaticFastValue::<DateTime>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<DateTime>| {
val.map(OwnedValue::Date).unwrap_or(OwnedValue::Null)
},
})
}
column_type => {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {column_type:?}, which is not supported for \
sorting by owned value yet.",
column_name
)))
}
}
}
Self::Score => Box::new(ScoreSegmentSortKeyComputer {
segment_computer: SortBySimilarityScore,
}),
};
Ok(ErasedColumnSegmentSortKeyComputer { inner })
}
}
pub struct ErasedColumnSegmentSortKeyComputer {
inner: Box<dyn ErasedSegmentSortKeyComputer>,
}
impl SegmentSortKeyComputer for ErasedColumnSegmentSortKeyComputer {
type SortKey = OwnedValue;
type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Option<u64> {
self.inner.segment_sort_key(doc, score)
}
fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> OwnedValue {
self.inner.convert_segment_sort_key(segment_sort_key)
}
}
#[cfg(test)]
mod tests {
use crate::collector::sort_key::{ComparatorEnum, SortByErasedType};
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{OwnedValue, Schema, FAST, TEXT};
use crate::Index;
#[test]
fn test_sort_by_owned_u64() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(10), OwnedValue::U64(2), OwnedValue::Null]
);
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("id"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::U64(2), OwnedValue::U64(10), OwnedValue::Null]
);
}
#[test]
fn test_sort_by_owned_string() {
let mut schema_builder = Schema::builder();
let city_field = schema_builder.add_text_field("city", FAST | TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(city_field => "tokyo")).unwrap();
writer.add_document(doc!(city_field => "austin")).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("city"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Str("austin".to_string()),
OwnedValue::Str("tokyo".to_string()),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(id_field => 10u64)).unwrap();
writer.add_document(doc!(id_field => 2u64)).unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("id"), ComparatorEnum::Reverse));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![OwnedValue::Null, OwnedValue::U64(2), OwnedValue::U64(10)]
);
}
#[test]
fn test_sort_by_owned_score() {
let mut schema_builder = Schema::builder();
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.add_document(doc!(body_field => "a a")).unwrap();
writer.add_document(doc!(body_field => "a")).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query_parser = crate::query::QueryParser::for_index(&index, vec![body_field]);
let query = query_parser.parse_query("a").unwrap();
// Sort by score descending (Natural)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_score(), ComparatorEnum::Natural));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] > values[1]);
// Sort by score ascending (ReverseNoneLower)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_score(),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&query, &collector).unwrap();
let values: Vec<f64> = top_docs
.into_iter()
.map(|(key, _)| match key {
OwnedValue::F64(val) => val,
_ => panic!("Wrong type {key:?}"),
})
.collect();
assert_eq!(values.len(), 2);
assert!(values[0] < values[1]);
}
}

View File

@@ -63,8 +63,8 @@ impl SortKeyComputer for SortBySimilarityScore {
impl SegmentSortKeyComputer for SortBySimilarityScore { impl SegmentSortKeyComputer for SortBySimilarityScore {
type SortKey = Score; type SortKey = Score;
type SegmentSortKey = Score; type SegmentSortKey = Score;
type SegmentComparator = NaturalComparator;
#[inline(always)] #[inline(always)]
fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score { fn segment_sort_key(&mut self, _doc: DocId, score: Score) -> Score {

View File

@@ -34,9 +34,7 @@ impl<T: FastValue> SortByStaticFastValue<T> {
impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> { impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
type Child = SortByFastValueSegmentSortKeyComputer<T>; type Child = SortByFastValueSegmentSortKeyComputer<T>;
type SortKey = Option<T>; type SortKey = Option<T>;
type Comparator = NaturalComparator; type Comparator = NaturalComparator;
fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> { fn check_schema(&self, schema: &crate::schema::Schema) -> crate::Result<()> {
@@ -84,8 +82,8 @@ pub struct SortByFastValueSegmentSortKeyComputer<T> {
impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> { impl<T: FastValue> SegmentSortKeyComputer for SortByFastValueSegmentSortKeyComputer<T> {
type SortKey = Option<T>; type SortKey = Option<T>;
type SegmentSortKey = Option<u64>; type SegmentSortKey = Option<u64>;
type SegmentComparator = NaturalComparator;
#[inline(always)] #[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey {

View File

@@ -30,9 +30,7 @@ impl SortByString {
impl SortKeyComputer for SortByString { impl SortKeyComputer for SortByString {
type SortKey = Option<String>; type SortKey = Option<String>;
type Child = ByStringColumnSegmentSortKeyComputer; type Child = ByStringColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator; type Comparator = NaturalComparator;
fn segment_sort_key_computer( fn segment_sort_key_computer(
@@ -50,8 +48,8 @@ pub struct ByStringColumnSegmentSortKeyComputer {
impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
type SortKey = Option<String>; type SortKey = Option<String>;
type SegmentSortKey = Option<TermOrdinal>; type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
#[inline(always)] #[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> { fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
@@ -60,6 +58,8 @@ impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer {
} }
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> { fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<String> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?; let term_ord = term_ord_opt?;
let str_column = self.str_column_opt.as_ref()?; let str_column = self.str_column_opt.as_ref()?;
let mut bytes = Vec::new(); let mut bytes = Vec::new();

View File

@@ -12,13 +12,21 @@ use crate::{DocAddress, DocId, Result, Score, SegmentReader};
/// It is the segment local version of the [`SortKeyComputer`]. /// It is the segment local version of the [`SortKeyComputer`].
pub trait SegmentSortKeyComputer: 'static { pub trait SegmentSortKeyComputer: 'static {
/// The final score being emitted. /// The final score being emitted.
type SortKey: 'static + PartialOrd + Send + Sync + Clone; type SortKey: 'static + Send + Sync + Clone;
/// Sort key used at the segment level by the `SegmentSortKeyComputer`. /// Sort key used at the segment level by the `SegmentSortKeyComputer`.
/// ///
/// It is typically small like a `u64`, and is meant to be converted /// It is typically small like a `u64`, and is meant to be converted
/// to the final score at the end of the collection of the segment. /// to the final score at the end of the collection of the segment.
type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; type SegmentSortKey: 'static + Clone + Send + Sync + Clone;
/// Comparator type.
type SegmentComparator: Comparator<Self::SegmentSortKey> + 'static;
/// Returns the segment sort key comparator.
fn segment_comparator(&self) -> Self::SegmentComparator {
Self::SegmentComparator::default()
}
/// Computes the sort key for the given document and score. /// Computes the sort key for the given document and score.
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
@@ -47,7 +55,7 @@ pub trait SegmentSortKeyComputer: 'static {
left: &Self::SegmentSortKey, left: &Self::SegmentSortKey,
right: &Self::SegmentSortKey, right: &Self::SegmentSortKey,
) -> Ordering { ) -> Ordering {
NaturalComparator.compare(left, right) self.segment_comparator().compare(left, right)
} }
/// Implementing this method makes it possible to avoid computing /// Implementing this method makes it possible to avoid computing
@@ -81,7 +89,7 @@ pub trait SegmentSortKeyComputer: 'static {
/// the sort key at a segment scale. /// the sort key at a segment scale.
pub trait SortKeyComputer: Sync { pub trait SortKeyComputer: Sync {
/// The sort key type. /// The sort key type.
type SortKey: 'static + Send + Sync + PartialOrd + Clone + std::fmt::Debug; type SortKey: 'static + Send + Sync + Clone + std::fmt::Debug;
/// Type of the associated [`SegmentSortKeyComputer`]. /// Type of the associated [`SegmentSortKeyComputer`].
type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>; type Child: SegmentSortKeyComputer<SortKey = Self::SortKey>;
/// Comparator type. /// Comparator type.
@@ -136,10 +144,7 @@ where
HeadSortKeyComputer: SortKeyComputer, HeadSortKeyComputer: SortKeyComputer,
TailSortKeyComputer: SortKeyComputer, TailSortKeyComputer: SortKeyComputer,
{ {
type SortKey = ( type SortKey = (HeadSortKeyComputer::SortKey, TailSortKeyComputer::SortKey);
<HeadSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
<TailSortKeyComputer::Child as SegmentSortKeyComputer>::SortKey,
);
type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child);
type Comparator = ( type Comparator = (
@@ -188,6 +193,11 @@ where
TailSegmentSortKeyComputer::SegmentSortKey, TailSegmentSortKeyComputer::SegmentSortKey,
); );
type SegmentComparator = (
HeadSegmentSortKeyComputer::SegmentComparator,
TailSegmentSortKeyComputer::SegmentComparator,
);
/// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on /// A SegmentSortKeyComputer maps to a SegmentSortKey, but it can also decide on
/// its ordering. /// its ordering.
/// ///
@@ -269,11 +279,12 @@ impl<T, PreviousScore, NewScore> SegmentSortKeyComputer
for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore> for MappedSegmentSortKeyComputer<T, PreviousScore, NewScore>
where where
T: SegmentSortKeyComputer<SortKey = PreviousScore>, T: SegmentSortKeyComputer<SortKey = PreviousScore>,
PreviousScore: 'static + Clone + Send + Sync + PartialOrd, PreviousScore: 'static + Clone + Send + Sync,
NewScore: 'static + Clone + Send + Sync + PartialOrd, NewScore: 'static + Clone + Send + Sync,
{ {
type SortKey = NewScore; type SortKey = NewScore;
type SegmentSortKey = T::SegmentSortKey; type SegmentSortKey = T::SegmentSortKey;
type SegmentComparator = T::SegmentComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey {
self.sort_key_computer.segment_sort_key(doc, score) self.sort_key_computer.segment_sort_key(doc, score)
@@ -463,6 +474,7 @@ where
{ {
type SortKey = TSortKey; type SortKey = TSortKey;
type SegmentSortKey = TSortKey; type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey {
(self)(doc) (self)(doc)

View File

@@ -324,7 +324,7 @@ impl TopDocs {
sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static, sort_key_computer: impl SortKeyComputer<SortKey = TSortKey> + Send + 'static,
) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>> ) -> impl Collector<Fruit = Vec<(TSortKey, DocAddress)>>
where where
TSortKey: 'static + Clone + Send + Sync + PartialOrd + std::fmt::Debug, TSortKey: 'static + Clone + Send + Sync + std::fmt::Debug,
{ {
TopBySortKeyCollector::new(sort_key_computer, self.doc_range()) TopBySortKeyCollector::new(sort_key_computer, self.doc_range())
} }
@@ -445,7 +445,7 @@ where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn, F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey, TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>: TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey>, SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug, TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{ {
type SortKey = TSortKey; type SortKey = TSortKey;
@@ -480,6 +480,7 @@ where
{ {
type SortKey = TSortKey; type SortKey = TSortKey;
type SegmentSortKey = TSortKey; type SegmentSortKey = TSortKey;
type SegmentComparator = NaturalComparator;
fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { fn segment_sort_key(&mut self, doc: DocId, score: Score) -> TSortKey {
(self.sort_key_fn)(doc, score) (self.sort_key_fn)(doc, score)

View File

@@ -48,7 +48,15 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>, F: Sized + Sync + Fn(A) -> crate::Result<R>,
{ {
match self { match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(), Executor::SingleThread => {
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::ThreadPool(pool) => { Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect(); let args: Vec<A> = args.collect();
let num_fruits = args.len(); let num_fruits = args.len();

View File

@@ -1,3 +1,5 @@
mod file_watcher;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::fs::{self, File, OpenOptions}; use std::fs::{self, File, OpenOptions};
@@ -7,6 +9,7 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, Weak}; use std::sync::{Arc, RwLock, Weak};
use common::StableDeref; use common::StableDeref;
use file_watcher::FileWatcher;
use fs4::fs_std::FileExt; use fs4::fs_std::FileExt;
#[cfg(all(feature = "mmap", unix))] #[cfg(all(feature = "mmap", unix))]
pub use memmap2::Advice; pub use memmap2::Advice;
@@ -18,7 +21,6 @@ use crate::core::META_FILEPATH;
use crate::directory::error::{ use crate::directory::error::{
DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError, DeleteError, LockError, OpenDirectoryError, OpenReadError, OpenWriteError,
}; };
use crate::directory::file_watcher::FileWatcher;
use crate::directory::{ use crate::directory::{
AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite, AntiCallToken, Directory, DirectoryLock, FileHandle, Lock, OwnedBytes, TerminatingWrite,
WatchCallback, WatchHandle, WritePtr, WatchCallback, WatchHandle, WritePtr,

View File

@@ -5,7 +5,6 @@ mod mmap_directory;
mod directory; mod directory;
mod directory_lock; mod directory_lock;
mod file_watcher;
pub mod footer; pub mod footer;
mod managed_directory; mod managed_directory;
mod ram_directory; mod ram_directory;

View File

@@ -42,7 +42,6 @@ pub trait DocSet: Send {
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`. /// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
/// ///
/// `target` has to be larger than or equal to `.doc()` when calling `seek`. /// `target` has to be larger than or equal to `.doc()` when calling `seek`.
/// If `target` is equal to `.doc()` then the DocSet should not advance.
fn seek(&mut self, target: DocId) -> DocId { fn seek(&mut self, target: DocId) -> DocId {
let mut doc = self.doc(); let mut doc = self.doc();
debug_assert!(doc <= target); debug_assert!(doc <= target);
@@ -167,19 +166,6 @@ pub trait DocSet: Send {
} }
} }
/// Consumes the `DocSet` and returns a Vec with all of the docs in the DocSet
/// including the current doc.
#[cfg(test)]
pub fn docset_to_doc_vec(mut doc_set: Box<dyn DocSet>) -> Vec<DocId> {
let mut output = Vec::new();
let mut doc = doc_set.doc();
while doc != TERMINATED {
output.push(doc);
doc = doc_set.advance();
}
output
}
impl DocSet for &mut dyn DocSet { impl DocSet for &mut dyn DocSet {
fn advance(&mut self) -> u32 { fn advance(&mut self) -> u32 {
(**self).advance() (**self).advance()

View File

@@ -113,7 +113,7 @@ mod tests {
IndexRecordOption::WithFreqs, IndexRecordOption::WithFreqs,
); );
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32, 0)?; let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
assert_eq!(scorer.doc(), 0); assert_eq!(scorer.doc(), 0);
assert!((scorer.score() - 0.22920431).abs() < 0.001f32); assert!((scorer.score() - 0.22920431).abs() < 0.001f32);
assert_eq!(scorer.advance(), 1); assert_eq!(scorer.advance(), 1);
@@ -142,7 +142,7 @@ mod tests {
IndexRecordOption::WithFreqs, IndexRecordOption::WithFreqs,
); );
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32, 0)?; let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
assert_eq!(scorer.doc(), 0); assert_eq!(scorer.doc(), 0);
assert!((scorer.score() - 0.22920431).abs() < 0.001f32); assert!((scorer.score() - 0.22920431).abs() < 0.001f32);
assert_eq!(scorer.advance(), 1); assert_eq!(scorer.advance(), 1);

View File

@@ -404,7 +404,10 @@ mod tests {
schema_builder.build() schema_builder.build()
}; };
let index_metas = IndexMeta { let index_metas = IndexMeta {
index_settings: IndexSettings::default(), index_settings: IndexSettings {
docstore_compression: Compressor::None,
..Default::default()
},
segments: Vec::new(), segments: Vec::new(),
schema, schema,
opstamp: 0u64, opstamp: 0u64,
@@ -413,7 +416,7 @@ mod tests {
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed"); let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!( assert_eq!(
json, json,
r#"{"index_settings":{"docstore_compression":"lz4","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"# r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
); );
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap(); let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
@@ -494,6 +497,8 @@ mod tests {
#[test] #[test]
#[cfg(feature = "lz4-compression")] #[cfg(feature = "lz4-compression")]
fn test_index_settings_default() { fn test_index_settings_default() {
use crate::store::Compressor;
let mut index_settings = IndexSettings::default(); let mut index_settings = IndexSettings::default();
assert_eq!( assert_eq!(
index_settings, index_settings,

View File

@@ -14,7 +14,6 @@ use crate::positions::PositionReader;
use crate::postings::{BlockSegmentPostings, SegmentPostings, TermInfo}; use crate::postings::{BlockSegmentPostings, SegmentPostings, TermInfo};
use crate::schema::{IndexRecordOption, Term, Type}; use crate::schema::{IndexRecordOption, Term, Type};
use crate::termdict::TermDictionary; use crate::termdict::TermDictionary;
use crate::DocId;
/// The inverted index reader is in charge of accessing /// The inverted index reader is in charge of accessing
/// the inverted index associated with a specific field. /// the inverted index associated with a specific field.
@@ -193,34 +192,9 @@ impl InvertedIndexReader {
term: &Term, term: &Term,
option: IndexRecordOption, option: IndexRecordOption,
) -> io::Result<Option<BlockSegmentPostings>> { ) -> io::Result<Option<BlockSegmentPostings>> {
let Some(term_info) = self.get_term_info(term)? else { self.get_term_info(term)?
return Ok(None); .map(move |term_info| self.read_block_postings_from_terminfo(&term_info, option))
}; .transpose()
let block_postings_not_loaded =
self.read_block_postings_from_terminfo(&term_info, option)?;
Ok(Some(block_postings_not_loaded))
}
/// Returns a block postings given a `term_info`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
pub(crate) fn read_block_postings_from_terminfo_with_seek(
&self,
term_info: &TermInfo,
requested_option: IndexRecordOption,
seek_doc: DocId,
) -> io::Result<(BlockSegmentPostings, usize)> {
let postings_data = self
.postings_file_slice
.slice(term_info.postings_range.clone());
BlockSegmentPostings::open(
term_info.doc_freq,
postings_data,
self.record_option,
requested_option,
seek_doc,
)
} }
/// Returns a block postings given a `term_info`. /// Returns a block postings given a `term_info`.
@@ -232,9 +206,15 @@ impl InvertedIndexReader {
term_info: &TermInfo, term_info: &TermInfo,
requested_option: IndexRecordOption, requested_option: IndexRecordOption,
) -> io::Result<BlockSegmentPostings> { ) -> io::Result<BlockSegmentPostings> {
let (block_segment_postings, _) = let postings_data = self
self.read_block_postings_from_terminfo_with_seek(term_info, requested_option, 0)?; .postings_file_slice
Ok(block_segment_postings) .slice(term_info.postings_range.clone());
BlockSegmentPostings::open(
term_info.doc_freq,
postings_data,
self.record_option,
requested_option,
)
} }
/// Returns a posting object given a `term_info`. /// Returns a posting object given a `term_info`.
@@ -244,13 +224,13 @@ impl InvertedIndexReader {
pub fn read_postings_from_terminfo( pub fn read_postings_from_terminfo(
&self, &self,
term_info: &TermInfo, term_info: &TermInfo,
record_option: IndexRecordOption, option: IndexRecordOption,
seek_doc: DocId,
) -> io::Result<SegmentPostings> { ) -> io::Result<SegmentPostings> {
let (block_segment_postings, position_within_block) = let option = option.downgrade(self.record_option);
self.read_block_postings_from_terminfo_with_seek(term_info, record_option, seek_doc)?;
let block_postings = self.read_block_postings_from_terminfo(term_info, option)?;
let position_reader = { let position_reader = {
if record_option.has_positions() { if option.has_positions() {
let positions_data = self let positions_data = self
.positions_file_slice .positions_file_slice
.read_bytes_slice(term_info.positions_range.clone())?; .read_bytes_slice(term_info.positions_range.clone())?;
@@ -261,9 +241,8 @@ impl InvertedIndexReader {
} }
}; };
Ok(SegmentPostings::from_block_postings( Ok(SegmentPostings::from_block_postings(
block_segment_postings, block_postings,
position_reader, position_reader,
position_within_block,
)) ))
} }
@@ -289,7 +268,7 @@ impl InvertedIndexReader {
option: IndexRecordOption, option: IndexRecordOption,
) -> io::Result<Option<SegmentPostings>> { ) -> io::Result<Option<SegmentPostings>> {
self.get_term_info(term)? self.get_term_info(term)?
.map(move |term_info| self.read_postings_from_terminfo(&term_info, option, 0u32)) .map(move |term_info| self.read_postings_from_terminfo(&term_info, option))
.transpose() .transpose()
} }

View File

@@ -4,19 +4,20 @@ use std::sync::{Arc, RwLock, Weak};
use super::operation::DeleteOperation; use super::operation::DeleteOperation;
use crate::Opstamp; use crate::Opstamp;
// The DeleteQueue is conceptually similar to a multi-consumer, /// The DeleteQueue is conceptually similar to a multi-consumer,
// single-producer broadcast channel. /// single-producer broadcast channel.
// ///
// All consumers will receive all messages. /// All consumers will receive all messages.
// ///
// Consumers of the delete queue hold a `DeleteCursor`, /// Consumers of the delete queue hold a `DeleteCursor`,
// which points to a specific place in the `DeleteQueue`. /// which points to a specific place in the `DeleteQueue`.
// ///
// New consumers can be created in two ways: /// New consumers can be created in two ways:
// - calling `delete_queue.cursor()` returns a cursor, which will include all future delete operations /// - calling `delete_queue.cursor()` returns a cursor, which will include all future delete
// (and some or none of the past operations... The client is in charge of checking the opstamps.). /// operations (and some or none of the past operations... The client is in charge of checking the
// - cloning an existing cursor returns a new cursor, which is at the exact same position, and can /// opstamps.).
// now advance independently from the original cursor. /// - cloning an existing cursor returns a new cursor, which is at the exact same position, and can
/// now advance independently from the original cursor.
#[derive(Default)] #[derive(Default)]
struct InnerDeleteQueue { struct InnerDeleteQueue {
writer: Vec<DeleteOperation>, writer: Vec<DeleteOperation>,
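The cursor semantics spelled out in the comment above can be illustrated with a small standalone sketch (illustrative only; `Log` and `Cursor` are made-up stand-ins, not the actual DeleteQueue types):

use std::sync::{Arc, RwLock};

// A shared append-only log; every cursor reads the same underlying Vec.
#[derive(Default)]
struct Log(Arc<RwLock<Vec<u64>>>);

// Each consumer owns its position; cloning yields an independent cursor
// starting at the same position.
#[derive(Clone)]
struct Cursor {
    log: Arc<RwLock<Vec<u64>>>,
    pos: usize,
}

impl Log {
    fn push(&self, op: u64) {
        self.0.write().unwrap().push(op);
    }

    // A freshly created cursor only sees operations pushed after its creation.
    fn cursor(&self) -> Cursor {
        let pos = self.0.read().unwrap().len();
        Cursor {
            log: self.0.clone(),
            pos,
        }
    }
}

impl Cursor {
    fn next(&mut self) -> Option<u64> {
        let item = self.log.read().unwrap().get(self.pos).copied();
        if item.is_some() {
            self.pos += 1;
        }
        item
    }
}

fn main() {
    let log = Log::default();
    let mut c1 = log.cursor();
    log.push(1);
    let mut c2 = log.cursor(); // created later: does not see op 1
    log.push(2);
    assert_eq!((c1.next(), c1.next(), c1.next()), (Some(1), Some(2), None));
    assert_eq!((c2.next(), c2.next()), (Some(2), None));
}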
@@ -249,12 +250,7 @@ mod tests {
struct DummyWeight; struct DummyWeight;
impl Weight for DummyWeight { impl Weight for DummyWeight {
fn scorer( fn scorer(&self, _reader: &SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
_reader: &SegmentReader,
_boost: Score,
_seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
Err(crate::TantivyError::InternalError("dummy impl".to_owned())) Err(crate::TantivyError::InternalError("dummy impl".to_owned()))
} }

View File

@@ -367,11 +367,8 @@ impl IndexMerger {
for (segment_ord, term_info) in merged_terms.current_segment_ords_and_term_infos() { for (segment_ord, term_info) in merged_terms.current_segment_ords_and_term_infos() {
let segment_reader = &self.readers[segment_ord]; let segment_reader = &self.readers[segment_ord];
let inverted_index: &InvertedIndexReader = &field_readers[segment_ord]; let inverted_index: &InvertedIndexReader = &field_readers[segment_ord];
let segment_postings = inverted_index.read_postings_from_terminfo( let segment_postings = inverted_index
&term_info, .read_postings_from_terminfo(&term_info, segment_postings_option)?;
segment_postings_option,
0u32,
)?;
let alive_bitset_opt = segment_reader.alive_bitset(); let alive_bitset_opt = segment_reader.alive_bitset();
let doc_freq = if let Some(alive_bitset) = alive_bitset_opt { let doc_freq = if let Some(alive_bitset) = alive_bitset_opt {
segment_postings.doc_freq_given_deletes(alive_bitset) segment_postings.doc_freq_given_deletes(alive_bitset)

View File

@@ -4,6 +4,7 @@
//! `IndexWriter` is the main entry point for that, which is created from //! `IndexWriter` is the main entry point for that, which is created from
//! [`Index::writer`](crate::Index::writer). //! [`Index::writer`](crate::Index::writer).
/// Delete queue implementation for broadcasting delete operations to consumers.
pub(crate) mod delete_queue; pub(crate) mod delete_queue;
pub(crate) mod path_to_unordered_id; pub(crate) mod path_to_unordered_id;

View File

@@ -421,10 +421,9 @@ fn remap_and_write(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::path::{Path, PathBuf}; use std::path::Path;
use columnar::ColumnType; use columnar::ColumnType;
use tempfile::TempDir;
use crate::collector::{Count, TopDocs}; use crate::collector::{Count, TopDocs};
use crate::directory::RamDirectory; use crate::directory::RamDirectory;
@@ -1067,10 +1066,7 @@ mod tests {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", text_options); schema_builder.add_text_field("title", text_options);
let schema = schema_builder.build(); let schema = schema_builder.build();
let tempdir = TempDir::new().unwrap(); let index = Index::create_in_ram(schema);
let tempdir_path = PathBuf::from(tempdir.path());
Index::create_in_dir(&tempdir_path, schema).unwrap();
let index = Index::open_in_dir(tempdir_path).unwrap();
let schema = index.schema(); let schema = index.schema();
let mut index_writer = index.writer(50_000_000).unwrap(); let mut index_writer = index.writer(50_000_000).unwrap();
let title = schema.get_field("title").unwrap(); let title = schema.get_field("title").unwrap();

View File

@@ -17,6 +17,7 @@
//! //!
//! ```rust //! ```rust
//! # use std::path::Path; //! # use std::path::Path;
//! # use std::fs;
//! # use tempfile::TempDir; //! # use tempfile::TempDir;
//! # use tantivy::collector::TopDocs; //! # use tantivy::collector::TopDocs;
//! # use tantivy::query::QueryParser; //! # use tantivy::query::QueryParser;
@@ -27,8 +28,11 @@
//! # // Let's create a temporary directory for the //! # // Let's create a temporary directory for the
//! # // sake of this example //! # // sake of this example
//! # if let Ok(dir) = TempDir::new() { //! # if let Ok(dir) = TempDir::new() {
//! # run_example(dir.path()).unwrap(); //! # let index_path = dir.path().join("index");
//! # dir.close().unwrap(); //! # // In case the directory already exists, we remove it
//! # let _ = fs::remove_dir_all(&index_path);
//! # fs::create_dir_all(&index_path).unwrap();
//! # run_example(&index_path).unwrap();
//! # } //! # }
//! # } //! # }
//! # //! #
@@ -203,6 +207,7 @@ mod docset;
mod reader; mod reader;
#[cfg(test)] #[cfg(test)]
#[cfg(feature = "mmap")]
mod compat_tests; mod compat_tests;
pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer}; pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer};
@@ -1170,12 +1175,11 @@ pub mod tests {
#[test] #[test]
fn test_validate_checksum() -> crate::Result<()> { fn test_validate_checksum() -> crate::Result<()> {
let index_path = tempfile::tempdir().expect("dir");
let mut builder = Schema::builder(); let mut builder = Schema::builder();
let body = builder.add_text_field("body", TEXT | STORED); let body = builder.add_text_field("body", TEXT | STORED);
let schema = builder.build(); let schema = builder.build();
let index = Index::create_in_dir(&index_path, schema)?; let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer(50_000_000)?; let mut writer: IndexWriter = index.writer_for_tests()?;
writer.set_merge_policy(Box::new(NoMergePolicy)); writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..5000 { for _ in 0..5000 {
writer.add_document(doc!(body => "foo"))?; writer.add_document(doc!(body => "foo"))?;

View File

@@ -99,8 +99,7 @@ impl BlockSegmentPostings {
data: FileSlice, data: FileSlice,
mut record_option: IndexRecordOption, mut record_option: IndexRecordOption,
requested_option: IndexRecordOption, requested_option: IndexRecordOption,
seek_doc: DocId, ) -> io::Result<BlockSegmentPostings> {
) -> io::Result<(BlockSegmentPostings, usize)> {
let bytes = data.read_bytes()?; let bytes = data.read_bytes()?;
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, bytes)?; let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, bytes)?;
let skip_reader = match skip_data_opt { let skip_reader = match skip_data_opt {
@@ -126,7 +125,7 @@ impl BlockSegmentPostings {
(_, _) => FreqReadingOption::ReadFreq, (_, _) => FreqReadingOption::ReadFreq,
}; };
let mut block_segment_postings: BlockSegmentPostings = BlockSegmentPostings { let mut block_segment_postings = BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED), doc_decoder: BlockDecoder::with_val(TERMINATED),
block_loaded: false, block_loaded: false,
freq_decoder: BlockDecoder::with_val(1), freq_decoder: BlockDecoder::with_val(1),
@@ -136,13 +135,8 @@ impl BlockSegmentPostings {
data: postings_data, data: postings_data,
skip_reader, skip_reader,
}; };
let inner_pos = if seek_doc == 0 { block_segment_postings.load_block();
block_segment_postings.load_block(); Ok(block_segment_postings)
0
} else {
block_segment_postings.seek(seek_doc)
};
Ok((block_segment_postings, inner_pos))
} }
/// Returns the block_max_score for the current block. /// Returns the block_max_score for the current block.
@@ -264,9 +258,7 @@ impl BlockSegmentPostings {
self.doc_decoder.output_len self.doc_decoder.output_len
} }
/// Position on a block that may contain `target_doc`, and returns the /// Position on a block that may contain `target_doc`.
/// position of the first document greater than or equal to `target_doc`
/// within that block.
/// ///
/// If all docs are smaller than target, the block loaded may be empty, /// If all docs are smaller than target, the block loaded may be empty,
/// or be the last, incomplete VInt block. /// or be the last, incomplete VInt block.
@@ -461,7 +453,7 @@ mod tests {
doc_ids.push(130); doc_ids.push(130);
{ {
let block_segments = build_block_postings(&doc_ids)?; let block_segments = build_block_postings(&doc_ids)?;
let mut docset = SegmentPostings::from_block_postings(block_segments, None, 0); let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(128), 129); assert_eq!(docset.seek(128), 129);
assert_eq!(docset.doc(), 129); assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130); assert_eq!(docset.advance(), 130);
@@ -469,8 +461,8 @@ mod tests {
assert_eq!(docset.advance(), TERMINATED); assert_eq!(docset.advance(), TERMINATED);
} }
{ {
let block_segments = build_block_postings(&doc_ids)?; let block_segments = build_block_postings(&doc_ids).unwrap();
let mut docset = SegmentPostings::from_block_postings(block_segments, None, 0); let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(129), 129); assert_eq!(docset.seek(129), 129);
assert_eq!(docset.doc(), 129); assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130); assert_eq!(docset.advance(), 130);
@@ -479,7 +471,7 @@ mod tests {
} }
{ {
let block_segments = build_block_postings(&doc_ids)?; let block_segments = build_block_postings(&doc_ids)?;
let mut docset = SegmentPostings::from_block_postings(block_segments, None, 0); let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.doc(), 0); assert_eq!(docset.doc(), 0);
assert_eq!(docset.seek(131), TERMINATED); assert_eq!(docset.seek(131), TERMINATED);
assert_eq!(docset.doc(), TERMINATED); assert_eq!(docset.doc(), TERMINATED);

View File

@@ -527,6 +527,7 @@ pub(crate) mod tests {
} }
impl<TScorer: Scorer> Scorer for UnoptimizedDocSet<TScorer> { impl<TScorer: Scorer> Scorer for UnoptimizedDocSet<TScorer> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.0.score() self.0.score()
} }

View File

@@ -79,15 +79,14 @@ impl SegmentPostings {
.close_term(docs.len() as u32) .close_term(docs.len() as u32)
.expect("In memory Serialization should never fail."); .expect("In memory Serialization should never fail.");
} }
let (block_segment_postings, position_within_block) = BlockSegmentPostings::open( let block_segment_postings = BlockSegmentPostings::open(
docs.len() as u32, docs.len() as u32,
FileSlice::from(buffer), FileSlice::from(buffer),
IndexRecordOption::Basic, IndexRecordOption::Basic,
IndexRecordOption::Basic, IndexRecordOption::Basic,
0u32,
) )
.unwrap(); .unwrap();
SegmentPostings::from_block_postings(block_segment_postings, None, position_within_block) SegmentPostings::from_block_postings(block_segment_postings, None)
} }
/// Helper functions to create `SegmentPostings` for tests. /// Helper functions to create `SegmentPostings` for tests.
@@ -128,29 +127,28 @@ impl SegmentPostings {
postings_serializer postings_serializer
.close_term(doc_and_tfs.len() as u32) .close_term(doc_and_tfs.len() as u32)
.unwrap(); .unwrap();
let (block_segment_postings, position_within_block) = BlockSegmentPostings::open( let block_segment_postings = BlockSegmentPostings::open(
doc_and_tfs.len() as u32, doc_and_tfs.len() as u32,
FileSlice::from(buffer), FileSlice::from(buffer),
IndexRecordOption::WithFreqs, IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs, IndexRecordOption::WithFreqs,
0u32,
) )
.unwrap(); .unwrap();
SegmentPostings::from_block_postings(block_segment_postings, None, position_within_block) SegmentPostings::from_block_postings(block_segment_postings, None)
} }
/// Creates a Segment Postings from a /// Reads a Segment postings from an &[u8]
/// - `BlockSegmentPostings`, ///
/// - a position reader /// * `len` - number of documents in the posting list.
/// - a target document to seek to /// * `data` - data array. The complete data is not necessarily used.
/// * `freq_handler` - the freq handler is in charge of decoding frequencies and/or positions
pub(crate) fn from_block_postings( pub(crate) fn from_block_postings(
segment_block_postings: BlockSegmentPostings, segment_block_postings: BlockSegmentPostings,
position_reader: Option<PositionReader>, position_reader: Option<PositionReader>,
position_within_block: usize,
) -> SegmentPostings { ) -> SegmentPostings {
SegmentPostings { SegmentPostings {
block_cursor: segment_block_postings, block_cursor: segment_block_postings,
cur: position_within_block, cur: 0, // cursor within the block
position_reader, position_reader,
} }
} }

View File

@@ -6,17 +6,21 @@ use crate::{DocId, Score, TERMINATED};
// doc num bits uses the following encoding: // doc num bits uses the following encoding:
// given 0b a b cdefgh // given 0b a b cdefgh
// |1|2| 3 | // |1|2|3| 4 |
// - 1: unused // - 1: unused
// - 2: is delta-1 encoded. 0 if not, 1 if yes // - 2: is delta-1 encoded. 0 if not, 1 if yes
// - 3: a 6 bit number in 0..=32, the actual bitwidth // - 3: unused
// - 4: a 5 bit number in 0..32, the actual bitwidth. Bitpacking could in theory say this is 32
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
// be represented on 31b without delta encoding.
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 { fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
assert!(bitwidth < 32);
bitwidth | ((delta_1 as u8) << 6) bitwidth | ((delta_1 as u8) << 6)
} }
fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) { fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
let delta_1 = ((raw_bitwidth >> 6) & 1) != 0; let delta_1 = ((raw_bitwidth >> 6) & 1) != 0;
let bitwidth = raw_bitwidth & 0x3f; let bitwidth = raw_bitwidth & 0x1f;
(bitwidth, delta_1) (bitwidth, delta_1)
} }
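A worked example of the encoding described above (a standalone restatement of the two helpers, shown only to make the bit layout concrete):

fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
    assert!(bitwidth < 32);
    bitwidth | ((delta_1 as u8) << 6)
}

fn decode_bitwidth(raw_bitwidth: u8) -> (u8, bool) {
    (raw_bitwidth & 0x1f, ((raw_bitwidth >> 6) & 1) != 0)
}

fn main() {
    // Bitwidth 17 with delta-1 encoding: bit 6 holds the flag, bits 0..=4 hold the 17.
    let raw = encode_bitwidth(17, true);
    assert_eq!(raw, 0b0101_0001);
    assert_eq!(decode_bitwidth(raw), (17, true));
}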
@@ -430,7 +434,7 @@ mod tests {
#[test] #[test]
fn test_encode_decode_bitwidth() { fn test_encode_decode_bitwidth() {
for bitwidth in 0..=32 { for bitwidth in 0..32 {
for delta_1 in [false, true] { for delta_1 in [false, true] {
assert_eq!( assert_eq!(
(bitwidth, delta_1), (bitwidth, delta_1),

View File

@@ -21,12 +21,7 @@ impl Query for AllQuery {
pub struct AllWeight; pub struct AllWeight;
impl Weight for AllWeight { impl Weight for AllWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
_seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc()); let all_scorer = AllScorer::new(reader.max_doc());
if boost != 1.0 { if boost != 1.0 {
Ok(Box::new(BoostScorer::new(all_scorer, boost))) Ok(Box::new(BoostScorer::new(all_scorer, boost)))
@@ -110,6 +105,7 @@ impl DocSet for AllScorer {
} }
impl Scorer for AllScorer { impl Scorer for AllScorer {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
1.0 1.0
} }
@@ -145,7 +141,7 @@ mod tests {
let weight = AllQuery.weight(EnableScoring::disabled_from_schema(&index.schema()))?; let weight = AllQuery.weight(EnableScoring::disabled_from_schema(&index.schema()))?;
{ {
let reader = searcher.segment_reader(0); let reader = searcher.segment_reader(0);
let mut scorer = weight.scorer(reader, 1.0, 0)?; let mut scorer = weight.scorer(reader, 1.0)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.advance(), 1u32); assert_eq!(scorer.advance(), 1u32);
assert_eq!(scorer.doc(), 1u32); assert_eq!(scorer.doc(), 1u32);
@@ -153,7 +149,7 @@ mod tests {
} }
{ {
let reader = searcher.segment_reader(1); let reader = searcher.segment_reader(1);
let mut scorer = weight.scorer(reader, 1.0, 0)?; let mut scorer = weight.scorer(reader, 1.0)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.advance(), TERMINATED); assert_eq!(scorer.advance(), TERMINATED);
} }
@@ -168,12 +164,12 @@ mod tests {
let weight = AllQuery.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let weight = AllQuery.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let reader = searcher.segment_reader(0); let reader = searcher.segment_reader(0);
{ {
let mut scorer = weight.scorer(reader, 2.0, 0)?; let mut scorer = weight.scorer(reader, 2.0)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 2.0); assert_eq!(scorer.score(), 2.0);
} }
{ {
let mut scorer = weight.scorer(reader, 1.5, 0)?; let mut scorer = weight.scorer(reader, 1.5)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.5); assert_eq!(scorer.score(), 1.5);
} }

View File

@@ -84,12 +84,7 @@ where
A: Automaton + Send + Sync + 'static, A: Automaton + Send + Sync + 'static,
A::State: Clone, A::State: Clone,
{ {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc(); let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc); let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?; let inverted_index = reader.inverted_index(self.field)?;
@@ -97,12 +92,8 @@ where
let mut term_stream = self.automaton_stream(term_dict)?; let mut term_stream = self.automaton_stream(term_dict)?;
while term_stream.advance() { while term_stream.advance() {
let term_info = term_stream.value(); let term_info = term_stream.value();
let (mut block_segment_postings, _) = inverted_index let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo_with_seek( .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
term_info,
IndexRecordOption::Basic,
seek_doc,
)?;
loop { loop {
let docs = block_segment_postings.docs(); let docs = block_segment_postings.docs();
if docs.is_empty() { if docs.is_empty() {
@@ -120,7 +111,7 @@ where
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) == doc { if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0)) Ok(Explanation::new("AutomatonScorer", 1.0))
} else { } else {
@@ -195,7 +186,7 @@ mod tests {
let automaton_weight = AutomatonWeight::new(field, PrefixedByA); let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader()?; let reader = index.reader()?;
let searcher = reader.searcher(); let searcher = reader.searcher();
let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.0); assert_eq!(scorer.score(), 1.0);
assert_eq!(scorer.advance(), 2u32); assert_eq!(scorer.advance(), 2u32);
@@ -212,7 +203,7 @@ mod tests {
let automaton_weight = AutomatonWeight::new(field, PrefixedByA); let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader()?; let reader = index.reader()?;
let searcher = reader.searcher(); let searcher = reader.searcher();
let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.32, 0)?; let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.32)?;
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.32); assert_eq!(scorer.score(), 1.32);
Ok(()) Ok(())
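
The automaton hunk above drops the `_with_seek` variant and goes back to plain `read_block_postings_from_terminfo`, draining each matching term's postings block by block into a `BitSet`. A hedged sketch of that collection loop with stand-in `BitSet`/`BlockPostings` types (the real ones are tantivy's):

```rust
// Toy sketch of the collection loop (stand-in `BitSet` / `BlockPostings`
// types): every term accepted by the automaton contributes its postings,
// block by block, to one bitset of matching docs.

struct BitSet {
    words: Vec<u64>,
}

impl BitSet {
    fn with_max_value(max_doc: u32) -> Self {
        BitSet { words: vec![0; (max_doc as usize + 63) / 64] }
    }
    fn insert(&mut self, doc: u32) {
        self.words[(doc / 64) as usize] |= 1u64 << (doc % 64);
    }
    fn contains(&self, doc: u32) -> bool {
        self.words[(doc / 64) as usize] & (1u64 << (doc % 64)) != 0
    }
}

/// Stand-in for block postings: yields matching docs in fixed-size blocks.
struct BlockPostings {
    docs: Vec<u32>,
    offset: usize,
}

impl BlockPostings {
    fn docs(&self) -> &[u32] {
        let end = (self.offset + 4).min(self.docs.len());
        &self.docs[self.offset..end]
    }
    fn advance(&mut self) {
        self.offset = (self.offset + 4).min(self.docs.len());
    }
}

fn main() {
    let max_doc = 16;
    let mut doc_bitset = BitSet::with_max_value(max_doc);
    // One postings list per term matched by the automaton.
    let postings_per_term = vec![
        BlockPostings { docs: vec![0, 2, 9], offset: 0 },
        BlockPostings { docs: vec![2, 3, 11, 15], offset: 0 },
    ];
    for mut block_segment_postings in postings_per_term {
        loop {
            let docs = block_segment_postings.docs();
            if docs.is_empty() {
                break;
            }
            for &doc in docs {
                doc_bitset.insert(doc);
            }
            block_segment_postings.advance();
        }
    }
    assert!(doc_bitset.contains(9) && doc_bitset.contains(15));
    assert!(!doc_bitset.contains(1));
}
```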

View File

@@ -12,7 +12,7 @@ use crate::query::{
intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur, intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
RequiredOptionalScorer, Scorer, Weight, RequiredOptionalScorer, Scorer, Weight,
}; };
use crate::{DocId, Score, TERMINATED}; use crate::{DocId, Score};
enum SpecializedScorer { enum SpecializedScorer {
TermUnion(Vec<TermScorer>), TermUnion(Vec<TermScorer>),
@@ -156,19 +156,6 @@ fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
} }
} }
fn create_scorer(
weight: &dyn Weight,
reader: &SegmentReader,
boost: Score,
target_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
if target_doc >= reader.max_doc() {
Ok(Box::new(EmptyScorer))
} else {
weight.scorer(reader, boost, target_doc)
}
}
enum ShouldScorersCombinationMethod { enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant. // Should scorers are irrelevant.
Ignored, Ignored,
@@ -220,29 +207,10 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
mut seek_first_doc: DocId,
) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> { ) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new(); let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new();
let (mut must_weights, other_weights): (Vec<(Occur, _)>, Vec<(Occur, _)>) = self for (occur, subweight) in &self.weights {
.weights let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader, boost)?;
.iter()
.map(|(occur, weight)| (*occur, weight))
.partition(|(occur, _weight)| *occur == Occur::Must);
// We start with the must weights in order to get the best "seek_first_doc", so that
// we can skip the first few documents of the other scorers.
must_weights.sort_by_key(|weight| weight.1.intersection_priority());
for (_, must_sub_weight) in must_weights {
let sub_scorer: Box<dyn Scorer> =
create_scorer(must_sub_weight.as_ref(), reader, boost, seek_first_doc)?;
seek_first_doc = seek_first_doc.max(sub_scorer.doc());
per_occur_scorers
.entry(Occur::Must)
.or_default()
.push(sub_scorer);
}
for (occur, sub_weight) in &other_weights {
let sub_scorer: Box<dyn Scorer> =
create_scorer(sub_weight.as_ref(), reader, boost, seek_first_doc)?;
per_occur_scorers per_occur_scorers
.entry(*occur) .entry(*occur)
.or_default() .or_default()
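
The removed block above partitioned out the `Occur::Must` weights, sorted them by `intersection_priority()`, and threaded the furthest first-document reached (`seek_first_doc`) into each subsequent scorer so later clauses could skip their head. A toy illustration of that (now removed) optimization, with stand-in types rather than tantivy's:

```rust
// Toy illustration of the removed optimization (stand-in types, not
// tantivy's): evaluate MUST clauses first, cheapest/most selective first,
// and let the furthest "first matching doc" seen so far become the starting
// point for every clause that follows.

struct Clause {
    /// Sorted doc ids matching this clause.
    docs: Vec<u32>,
    /// Lower value = evaluated earlier (mirrors `intersection_priority`).
    priority: u32,
}

impl Clause {
    /// First matching doc >= `target`, or u32::MAX if there is none.
    fn first_doc_from(&self, target: u32) -> u32 {
        self.docs.iter().copied().find(|&d| d >= target).unwrap_or(u32::MAX)
    }
}

fn main() {
    let mut must_clauses = vec![
        Clause { docs: vec![50, 400], priority: 10 },
        Clause { docs: (0..1000).collect(), priority: 100 },
    ];
    // Most selective clause first.
    must_clauses.sort_by_key(|clause| clause.priority);

    let mut seek_first_doc = 0u32;
    for clause in &must_clauses {
        // No MUST clause can match before `seek_first_doc`, so each later
        // clause starts its scan there instead of at doc 0.
        seek_first_doc = seek_first_doc.max(clause.first_doc_from(seek_first_doc));
    }
    // The dense second clause never had to look at docs 0..=49.
    assert_eq!(seek_first_doc, 50);
}
```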
@@ -256,10 +224,9 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner, score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
seek_doc: u32,
) -> crate::Result<SpecializedScorer> { ) -> crate::Result<SpecializedScorer> {
let num_docs = reader.num_docs(); let num_docs = reader.num_docs();
let mut per_occur_scorers = self.per_occur_scorers(reader, boost, seek_doc)?; let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
// Indicate how should clauses are combined with must clauses. // Indicate how should clauses are combined with must clauses.
let mut must_scorers: Vec<Box<dyn Scorer>> = let mut must_scorers: Vec<Box<dyn Scorer>> =
@@ -440,7 +407,7 @@ fn remove_and_count_all_and_empty_scorers(
if scorer.is::<AllScorer>() { if scorer.is::<AllScorer>() {
counts.num_all_scorers += 1; counts.num_all_scorers += 1;
false false
} else if scorer.doc() == TERMINATED { } else if scorer.is::<EmptyScorer>() {
counts.num_empty_scorers += 1; counts.num_empty_scorers += 1;
false false
} else { } else {
@@ -451,12 +418,7 @@ fn remove_and_count_all_and_empty_scorers(
} }
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> { impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs(); let num_docs = reader.num_docs();
if self.weights.is_empty() { if self.weights.is_empty() {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
@@ -465,15 +427,15 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
if occur == Occur::MustNot { if occur == Occur::MustNot {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
} else { } else {
weight.scorer(reader, boost, seek_doc) weight.scorer(reader, boost)
} }
} else if self.scoring_enabled { } else if self.scoring_enabled {
self.complex_scorer(reader, boost, &self.score_combiner_fn, seek_doc) self.complex_scorer(reader, boost, &self.score_combiner_fn)
.map(|specialized_scorer| { .map(|specialized_scorer| {
into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs) into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
}) })
} else { } else {
self.complex_scorer(reader, boost, DoNothingCombiner::default, seek_doc) self.complex_scorer(reader, boost, DoNothingCombiner::default)
.map(|specialized_scorer| { .map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs) into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
}) })
@@ -481,7 +443,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc { if scorer.seek(doc) != doc {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
@@ -505,7 +467,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score), callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> { ) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn, 0)?; let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer { match scorer {
SpecializedScorer::TermUnion(term_scorers) => { SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = BufferedUnionScorer::build( let mut union_scorer = BufferedUnionScorer::build(
@@ -527,7 +489,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(&[DocId]), callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> { ) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner, 0u32)?; let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
match scorer { match scorer {
@@ -562,7 +524,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score, callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> { ) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn, 0u32)?; let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer { match scorer {
SpecializedScorer::TermUnion(term_scorers) => { SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback); super::block_wand(term_scorers, threshold, callback);

View File

@@ -57,7 +57,7 @@ mod tests {
let query = query_parser.parse_query("+a")?; let query = query_parser.parse_query("+a")?;
let searcher = index.reader()?.searcher(); let searcher = index.reader()?.searcher();
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
Ok(()) Ok(())
} }
@@ -70,13 +70,13 @@ mod tests {
{ {
let query = query_parser.parse_query("+a +b +c")?; let query = query_parser.parse_query("+a +b +c")?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<Intersection<TermScorer>>()); assert!(scorer.is::<Intersection<TermScorer>>());
} }
{ {
let query = query_parser.parse_query("+a +(b c)")?; let query = query_parser.parse_query("+a +(b c)")?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>()); assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
} }
Ok(()) Ok(())
@@ -90,14 +90,14 @@ mod tests {
{ {
let query = query_parser.parse_query("+a b")?; let query = query_parser.parse_query("+a b")?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer assert!(scorer
.is::<RequiredOptionalScorer<Box<dyn Scorer>, Box<dyn Scorer>, SumCombiner>>()); .is::<RequiredOptionalScorer<Box<dyn Scorer>, Box<dyn Scorer>, SumCombiner>>());
} }
{ {
let query = query_parser.parse_query("+a b")?; let query = query_parser.parse_query("+a b")?;
let weight = query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; let weight = query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
} }
Ok(()) Ok(())
@@ -244,14 +244,12 @@ mod tests {
.weight(EnableScoring::enabled_from_searcher(&searcher)) .weight(EnableScoring::enabled_from_searcher(&searcher))
.unwrap(); .unwrap();
{ {
let mut boolean_scorer = let mut boolean_scorer = boolean_weight.scorer(searcher.segment_reader(0u32), 1.0)?;
boolean_weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?;
assert_eq!(boolean_scorer.doc(), 0u32); assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals!(boolean_scorer.score(), 0.84163445); assert_nearly_equals!(boolean_scorer.score(), 0.84163445);
} }
{ {
let mut boolean_scorer = let mut boolean_scorer = boolean_weight.scorer(searcher.segment_reader(0u32), 2.0)?;
boolean_weight.scorer(searcher.segment_reader(0u32), 2.0, 0)?;
assert_eq!(boolean_scorer.doc(), 0u32); assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals!(boolean_scorer.score(), 1.6832689); assert_nearly_equals!(boolean_scorer.score(), 1.6832689);
} }
@@ -345,7 +343,7 @@ mod tests {
(Occur::Must, term_match_some.box_clone()), (Occur::Must, term_match_some.box_clone()),
]); ]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
} }
{ {
@@ -355,7 +353,7 @@ mod tests {
(Occur::Must, term_match_none.box_clone()), (Occur::Must, term_match_none.box_clone()),
]); ]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<EmptyScorer>()); assert!(scorer.is::<EmptyScorer>());
} }
{ {
@@ -364,7 +362,7 @@ mod tests {
(Occur::Should, term_match_none.box_clone()), (Occur::Should, term_match_none.box_clone()),
]); ]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<AllScorer>()); assert!(scorer.is::<AllScorer>());
} }
{ {
@@ -373,7 +371,7 @@ mod tests {
(Occur::Should, term_match_none.box_clone()), (Occur::Should, term_match_none.box_clone()),
]); ]);
let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32, 0)?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
} }
Ok(()) Ok(())
@@ -613,134 +611,6 @@ mod tests {
Ok(()) Ok(())
} }
/// Test that the seek_doc parameter correctly skips documents in BooleanWeight::scorer.
///
/// When seek_doc is provided, the scorer should start from that document (or the first
/// matching document >= seek_doc), skipping earlier documents.
#[test]
pub fn test_boolean_weight_seek_doc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let value_field = schema_builder.add_u64_field("value", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Create 11 documents:
// doc 0: value=0
// doc 1: value=10
// doc 2: value=20
// ...
// doc 9: value=90
// doc 10: value=50 (matches range 30-70)
for i in 0..10 {
index_writer.add_document(doc!(
text_field => "hello",
value_field => (i * 10) as u64
))?;
}
index_writer.add_document(doc!(
text_field => "hello",
value_field => 50u64
))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
// Create a Boolean query: MUST(term "hello") AND MUST(range 30..=70)
// This should match docs with value in [30, 70]: docs 3, 4, 5, 6, 7, 10
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
let range_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Included(Term::from_field_u64(value_field, 30)),
Bound::Included(Term::from_field_u64(value_field, 70)),
));
let boolean_query =
BooleanQuery::new(vec![(Occur::Must, term_query), (Occur::Must, range_query)]);
let weight =
boolean_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let doc_when_seeking_from = |seek_from: DocId| {
let scorer = weight.scorer(segment_reader, 1.0f32, seek_from).unwrap();
crate::docset::docset_to_doc_vec(scorer)
};
// Expected matching docs: 3, 4, 5, 6, 7, 10 (values 30, 40, 50, 60, 70, 50)
assert_eq!(doc_when_seeking_from(0), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(1), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(3), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(4), vec![4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(7), vec![7, 10]);
assert_eq!(doc_when_seeking_from(8), vec![10]);
assert_eq!(doc_when_seeking_from(10), vec![10]);
assert_eq!(doc_when_seeking_from(11), Vec::<DocId>::new());
Ok(())
}
/// Test that the seek_doc parameter works correctly with SHOULD clauses.
#[test]
pub fn test_boolean_weight_seek_doc_with_should() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Create documents:
// doc 0: "a b"
// doc 1: "a"
// doc 2: "b"
// doc 3: "c"
// doc 4: "a b c"
index_writer.add_document(doc!(text_field => "a b"))?;
index_writer.add_document(doc!(text_field => "a"))?;
index_writer.add_document(doc!(text_field => "b"))?;
index_writer.add_document(doc!(text_field => "c"))?;
index_writer.add_document(doc!(text_field => "a b c"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
// Create a Boolean query: SHOULD(term "a") OR SHOULD(term "b")
// This should match docs 0, 1, 2, 4
let term_a: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "a"),
IndexRecordOption::Basic,
));
let term_b: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "b"),
IndexRecordOption::Basic,
));
let boolean_query =
BooleanQuery::new(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
let weight =
boolean_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let doc_when_seeking_from = |seek_from: DocId| {
let scorer = weight.scorer(segment_reader, 1.0f32, seek_from).unwrap();
crate::docset::docset_to_doc_vec(scorer)
};
// Expected matching docs: 0, 1, 2, 4
assert_eq!(doc_when_seeking_from(0), vec![0, 1, 2, 4]);
assert_eq!(doc_when_seeking_from(1), vec![1, 2, 4]);
assert_eq!(doc_when_seeking_from(2), vec![2, 4]);
assert_eq!(doc_when_seeking_from(3), vec![4]);
assert_eq!(doc_when_seeking_from(4), vec![4]);
assert_eq!(doc_when_seeking_from(5), Vec::<DocId>::new());
Ok(())
}
/// Test multiple AllScorer instances in different clause types. /// Test multiple AllScorer instances in different clause types.
/// ///
/// Verifies correct behavior when AllScorers appear in multiple positions. /// Verifies correct behavior when AllScorers appear in multiple positions.

View File

@@ -67,13 +67,8 @@ impl BoostWeight {
} }
impl Weight for BoostWeight { impl Weight for BoostWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self, self.weight.scorer(reader, boost * self.boost)
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
self.weight.scorer(reader, boost * self.boost, seek_doc)
} }
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
@@ -88,10 +83,6 @@ impl Weight for BoostWeight {
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> { fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
self.weight.count(reader) self.weight.count(reader)
} }
fn intersection_priority(&self) -> u32 {
self.weight.intersection_priority()
}
} }
pub(crate) struct BoostScorer<S: Scorer> { pub(crate) struct BoostScorer<S: Scorer> {
@@ -143,6 +134,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
} }
impl<S: Scorer> Scorer for BoostScorer<S> { impl<S: Scorer> Scorer for BoostScorer<S> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.underlying.score() * self.boost self.underlying.score() * self.boost
} }

View File

@@ -63,18 +63,13 @@ impl ConstWeight {
} }
impl Weight for ConstWeight { impl Weight for ConstWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self, let inner_scorer = self.weight.scorer(reader, boost)?;
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let inner_scorer = self.weight.scorer(reader, boost, seek_doc)?;
Ok(Box::new(ConstScorer::new(inner_scorer, boost * self.score))) Ok(Box::new(ConstScorer::new(inner_scorer, boost * self.score)))
} }
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc { if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!( return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match" "Document #({doc}) does not match"
@@ -89,10 +84,6 @@ impl Weight for ConstWeight {
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> { fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
self.weight.count(reader) self.weight.count(reader)
} }
fn intersection_priority(&self) -> u32 {
self.weight.intersection_priority()
}
} }
/// Wraps a `DocSet` and simply returns a constant `Scorer`. /// Wraps a `DocSet` and simply returns a constant `Scorer`.
@@ -146,6 +137,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
} }
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> { impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.score self.score
} }

View File

@@ -173,6 +173,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
for Disjunction<TScorer, TScoreCombiner> for Disjunction<TScorer, TScoreCombiner>
{ {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.current_score self.current_score
} }
@@ -307,6 +308,7 @@ mod tests {
} }
impl Scorer for DummyScorer { impl Scorer for DummyScorer {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0) self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
} }

View File

@@ -26,24 +26,13 @@ impl Query for EmptyQuery {
/// It is useful for tests and handling edge cases. /// It is useful for tests and handling edge cases.
pub struct EmptyWeight; pub struct EmptyWeight;
impl Weight for EmptyWeight { impl Weight for EmptyWeight {
fn scorer( fn scorer(&self, _reader: &SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
_reader: &SegmentReader,
_boost: Score,
_seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
} }
fn explain(&self, _reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, _reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
Err(does_not_match(doc)) Err(does_not_match(doc))
} }
/// Returns a priority number used to sort weights when running an
/// intersection.
fn intersection_priority(&self) -> u32 {
0u32
}
} }
/// `EmptyScorer` is a dummy `Scorer` in which no document matches. /// `EmptyScorer` is a dummy `Scorer` in which no document matches.
@@ -66,6 +55,7 @@ impl DocSet for EmptyScorer {
} }
impl Scorer for EmptyScorer { impl Scorer for EmptyScorer {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
0.0 0.0
} }

View File

@@ -84,6 +84,7 @@ where
TScorer: Scorer, TScorer: Scorer,
TDocSetExclude: DocSet + 'static, TDocSetExclude: DocSet + 'static,
{ {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.underlying_docset.score() self.underlying_docset.score()
} }

View File

@@ -98,12 +98,7 @@ pub struct ExistsWeight {
} }
impl Weight for ExistsWeight { impl Weight for ExistsWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
_seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let fast_field_reader = reader.fast_fields(); let fast_field_reader = reader.fast_fields();
let mut column_handles = fast_field_reader.dynamic_column_handles(&self.field_name)?; let mut column_handles = fast_field_reader.dynamic_column_handles(&self.field_name)?;
if self.field_type == Type::Json && self.json_subpaths { if self.field_type == Type::Json && self.json_subpaths {
@@ -171,7 +166,7 @@ impl Weight for ExistsWeight {
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc { if scorer.seek(doc) != doc {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }

View File

@@ -105,6 +105,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
} }
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> { impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
#[inline]
fn advance(&mut self) -> DocId { fn advance(&mut self) -> DocId {
let (left, right) = (&mut self.left, &mut self.right); let (left, right) = (&mut self.left, &mut self.right);
let mut candidate = left.advance(); let mut candidate = left.advance();
@@ -174,6 +175,7 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
.all(|docset| docset.seek_into_the_danger_zone(target)) .all(|docset| docset.seek_into_the_danger_zone(target))
} }
#[inline]
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
self.left.doc() self.left.doc()
} }
@@ -200,6 +202,7 @@ where
TScorer: Scorer, TScorer: Scorer,
TOtherScorer: Scorer, TOtherScorer: Scorer,
{ {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.left.score() self.left.score()
+ self.right.score() + self.right.score()

View File

@@ -81,6 +81,7 @@ impl<TPostings: Postings> DocSet for PhraseKind<TPostings> {
} }
impl<TPostings: Postings> Scorer for PhraseKind<TPostings> { impl<TPostings: Postings> Scorer for PhraseKind<TPostings> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
match self { match self {
PhraseKind::SinglePrefix { positions, .. } => { PhraseKind::SinglePrefix { positions, .. } => {
@@ -215,6 +216,7 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
} }
impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> { impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
// TODO modify score?? // TODO modify score??
self.phrase_scorer.score() self.phrase_scorer.score()

View File

@@ -42,11 +42,10 @@ impl PhrasePrefixWeight {
Ok(FieldNormReader::constant(reader.max_doc(), 1)) Ok(FieldNormReader::constant(reader.max_doc(), 1))
} }
pub(crate) fn prefix_phrase_scorer( pub(crate) fn phrase_scorer(
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
seek_doc: DocId,
) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings>>> { ) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings>>> {
let similarity_weight_opt = self let similarity_weight_opt = self
.similarity_weight_opt .similarity_weight_opt
@@ -55,16 +54,14 @@ impl PhrasePrefixWeight {
let fieldnorm_reader = self.fieldnorm_reader(reader)?; let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let mut term_postings_list = Vec::new(); let mut term_postings_list = Vec::new();
for &(offset, ref term) in &self.phrase_terms { for &(offset, ref term) in &self.phrase_terms {
let inverted_index = reader.inverted_index(term.field())?; if let Some(postings) = reader
let Some(term_info) = inverted_index.get_term_info(term)? else { .inverted_index(term.field())?
.read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
{
term_postings_list.push((offset, postings));
} else {
return Ok(None); return Ok(None);
}; }
let postings = inverted_index.read_postings_from_terminfo(
&term_info,
IndexRecordOption::WithFreqsAndPositions,
seek_doc,
)?;
term_postings_list.push((offset, postings));
} }
let inv_index = reader.inverted_index(self.prefix.1.field())?; let inv_index = reader.inverted_index(self.prefix.1.field())?;
@@ -117,13 +114,8 @@ impl PhrasePrefixWeight {
} }
impl Weight for PhrasePrefixWeight { impl Weight for PhrasePrefixWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self, if let Some(scorer) = self.phrase_scorer(reader, boost)? {
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.prefix_phrase_scorer(reader, boost, seek_doc)? {
Ok(Box::new(scorer)) Ok(Box::new(scorer))
} else { } else {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
@@ -131,7 +123,7 @@ impl Weight for PhrasePrefixWeight {
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.prefix_phrase_scorer(reader, 1.0, doc)?; let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() { if scorer_opt.is_none() {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
@@ -148,10 +140,6 @@ impl Weight for PhrasePrefixWeight {
} }
Ok(explanation) Ok(explanation)
} }
fn intersection_priority(&self) -> u32 {
50u32
}
} }
#[cfg(test)] #[cfg(test)]
@@ -199,7 +187,7 @@ mod tests {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let mut phrase_scorer = phrase_weight let mut phrase_scorer = phrase_weight
.prefix_phrase_scorer(searcher.segment_reader(0u32), 1.0, 0u32)? .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
.unwrap(); .unwrap();
assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2); assert_eq!(phrase_scorer.phrase_count(), 2);
@@ -226,7 +214,7 @@ mod tests {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let mut phrase_scorer = phrase_weight let mut phrase_scorer = phrase_weight
.prefix_phrase_scorer(searcher.segment_reader(0u32), 1.0, 0u32)? .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
.unwrap(); .unwrap();
assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2); assert_eq!(phrase_scorer.phrase_count(), 2);
@@ -250,7 +238,7 @@ mod tests {
.unwrap() .unwrap()
.is_none()); .is_none());
let weight = phrase_query.weight(enable_scoring).unwrap(); let weight = phrase_query.weight(enable_scoring).unwrap();
let mut phrase_scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let mut phrase_scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.advance(), 2); assert_eq!(phrase_scorer.advance(), 2);
assert_eq!(phrase_scorer.doc(), 2); assert_eq!(phrase_scorer.doc(), 2);
@@ -271,7 +259,7 @@ mod tests {
]); ]);
let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
let weight = phrase_query.weight(enable_scoring).unwrap(); let weight = phrase_query.weight(enable_scoring).unwrap();
let mut phrase_scorer = weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let mut phrase_scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert_eq!(phrase_scorer.advance(), TERMINATED); assert_eq!(phrase_scorer.advance(), TERMINATED);
Ok(()) Ok(())
} }
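
Both `PhrasePrefixWeight` and `PhraseWeight` now fetch postings via the `Option`-returning `read_postings`: if any phrase term has no postings in the segment, the whole phrase cannot match and the weight hands back `None`, which the caller turns into an `EmptyScorer`. A small self-contained sketch of that all-or-nothing lookup (the `read_postings` stand-in below is illustrative, not tantivy's API):

```rust
// Self-contained sketch of the all-or-nothing term lookup: a single phrase
// term with no postings in the segment rules out the whole phrase.

struct Postings {
    docs: Vec<u32>,
}

/// Stand-in for `inverted_index.read_postings(term, ..)`.
fn read_postings(term: &str) -> Option<Postings> {
    match term {
        "hello" => Some(Postings { docs: vec![1, 3] }),
        "world" => Some(Postings { docs: vec![3, 7] }),
        _ => None,
    }
}

fn phrase_postings(terms: &[(usize, &str)]) -> Option<Vec<(usize, Postings)>> {
    let mut term_postings_list = Vec::new();
    for &(offset, term) in terms {
        if let Some(postings) = read_postings(term) {
            term_postings_list.push((offset, postings));
        } else {
            // One missing term is enough: the phrase cannot match.
            return None;
        }
    }
    Some(term_postings_list)
}

fn main() {
    let postings = phrase_postings(&[(0, "hello"), (1, "world")]).unwrap();
    assert_eq!(postings[0].1.docs, vec![1, 3]);
    // "missing" has no postings, so the whole phrase yields no scorer.
    assert!(phrase_postings(&[(0, "hello"), (1, "missing")]).is_none());
}
```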

View File

@@ -84,7 +84,7 @@ pub(crate) mod tests {
let phrase_query = PhraseQuery::new(terms); let phrase_query = PhraseQuery::new(terms);
let phrase_weight = let phrase_weight =
phrase_query.phrase_weight(EnableScoring::disabled_from_schema(searcher.schema()))?; phrase_query.phrase_weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let mut phrase_scorer = phrase_weight.scorer(searcher.segment_reader(0), 1.0, 0)?; let mut phrase_scorer = phrase_weight.scorer(searcher.segment_reader(0), 1.0)?;
assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED); assert_eq!(phrase_scorer.advance(), TERMINATED);
Ok(()) Ok(())
@@ -343,43 +343,6 @@ pub(crate) mod tests {
Ok(()) Ok(())
} }
#[test]
pub fn test_phrase_weight_seek_doc() -> crate::Result<()> {
// Create an index with documents where the phrase "a b" appears in some of them.
// Documents: 0: "c d", 1: "a b", 2: "e f", 3: "a b c", 4: "g h", 5: "a b", 6: "i j"
let index = create_index(&["c d", "a b", "e f", "a b c", "g h", "a b", "i j"])?;
let text_field = index.schema().get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let phrase_query = PhraseQuery::new(vec![
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let phrase_weight =
phrase_query.phrase_weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
// Helper function to collect all docs from a scorer created with a given seek_doc
let docs_when_seeking_from = |seek_from: DocId| {
let scorer = phrase_weight
.scorer(segment_reader, 1.0f32, seek_from)
.unwrap();
crate::docset::docset_to_doc_vec(scorer)
};
// Documents with "a b": 1, 3, 5
assert_eq!(docs_when_seeking_from(0), vec![1, 3, 5]);
assert_eq!(docs_when_seeking_from(1), vec![1, 3, 5]);
assert_eq!(docs_when_seeking_from(2), vec![3, 5]);
assert_eq!(docs_when_seeking_from(3), vec![3, 5]);
assert_eq!(docs_when_seeking_from(4), vec![5]);
assert_eq!(docs_when_seeking_from(5), vec![5]);
assert_eq!(docs_when_seeking_from(6), Vec::<DocId>::new());
assert_eq!(docs_when_seeking_from(7), Vec::<DocId>::new());
Ok(())
}
#[test] #[test]
pub fn test_phrase_query_on_json() -> crate::Result<()> { pub fn test_phrase_query_on_json() -> crate::Result<()> {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
@@ -410,7 +373,7 @@ pub(crate) mod tests {
.weight(EnableScoring::disabled_from_schema(searcher.schema())) .weight(EnableScoring::disabled_from_schema(searcher.schema()))
.unwrap(); .unwrap();
let mut phrase_scorer = phrase_weight let mut phrase_scorer = phrase_weight
.scorer(searcher.segment_reader(0), 1.0f32, 0) .scorer(searcher.segment_reader(0), 1.0f32)
.unwrap(); .unwrap();
let mut docs = Vec::new(); let mut docs = Vec::new();
loop { loop {

View File

@@ -563,6 +563,7 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
} }
impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> { impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
let doc = self.doc(); let doc = self.doc();
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc); let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);

View File

@@ -43,7 +43,6 @@ impl PhraseWeight {
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
seek_doc: DocId,
) -> crate::Result<Option<PhraseScorer<SegmentPostings>>> { ) -> crate::Result<Option<PhraseScorer<SegmentPostings>>> {
let similarity_weight_opt = self let similarity_weight_opt = self
.similarity_weight_opt .similarity_weight_opt
@@ -52,16 +51,14 @@ impl PhraseWeight {
let fieldnorm_reader = self.fieldnorm_reader(reader)?; let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let mut term_postings_list = Vec::new(); let mut term_postings_list = Vec::new();
for &(offset, ref term) in &self.phrase_terms { for &(offset, ref term) in &self.phrase_terms {
let inverted_index = reader.inverted_index(term.field())?; if let Some(postings) = reader
let Some(term_info) = inverted_index.get_term_info(term)? else { .inverted_index(term.field())?
.read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
{
term_postings_list.push((offset, postings));
} else {
return Ok(None); return Ok(None);
}; }
let postings = inverted_index.read_postings_from_terminfo(
&term_info,
IndexRecordOption::WithFreqsAndPositions,
seek_doc,
)?;
term_postings_list.push((offset, postings));
} }
Ok(Some(PhraseScorer::new( Ok(Some(PhraseScorer::new(
term_postings_list, term_postings_list,
@@ -77,13 +74,8 @@ impl PhraseWeight {
} }
impl Weight for PhraseWeight { impl Weight for PhraseWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self, if let Some(scorer) = self.phrase_scorer(reader, boost)? {
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost, seek_doc)? {
Ok(Box::new(scorer)) Ok(Box::new(scorer))
} else { } else {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
@@ -91,12 +83,12 @@ impl Weight for PhraseWeight {
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0, doc)?; let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() { if scorer_opt.is_none() {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
let mut scorer = scorer_opt.unwrap(); let mut scorer = scorer_opt.unwrap();
if scorer.doc() != doc { if scorer.seek(doc) != doc {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
let fieldnorm_reader = self.fieldnorm_reader(reader)?; let fieldnorm_reader = self.fieldnorm_reader(reader)?;
@@ -108,10 +100,6 @@ impl Weight for PhraseWeight {
} }
Ok(explanation) Ok(explanation)
} }
fn intersection_priority(&self) -> u32 {
40u32
}
} }
#[cfg(test)] #[cfg(test)]
@@ -134,7 +122,7 @@ mod tests {
let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
let phrase_weight = phrase_query.phrase_weight(enable_scoring).unwrap(); let phrase_weight = phrase_query.phrase_weight(enable_scoring).unwrap();
let mut phrase_scorer = phrase_weight let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0, 0)? .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
.unwrap(); .unwrap();
assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2); assert_eq!(phrase_scorer.phrase_count(), 2);

View File

@@ -195,11 +195,8 @@ impl RegexPhraseWeight {
const SPARSE_TERM_DOC_THRESHOLD: u32 = 100; const SPARSE_TERM_DOC_THRESHOLD: u32 = 100;
for term_info in term_infos { for term_info in term_infos {
let mut term_posting = inverted_index.read_postings_from_terminfo( let mut term_posting = inverted_index
term_info, .read_postings_from_terminfo(term_info, IndexRecordOption::WithFreqsAndPositions)?;
IndexRecordOption::WithFreqsAndPositions,
0u32,
)?;
let num_docs = term_posting.doc_freq(); let num_docs = term_posting.doc_freq();
if num_docs < SPARSE_TERM_DOC_THRESHOLD { if num_docs < SPARSE_TERM_DOC_THRESHOLD {
@@ -272,12 +269,7 @@ impl RegexPhraseWeight {
} }
impl Weight for RegexPhraseWeight { impl Weight for RegexPhraseWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
_seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? { if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(Box::new(scorer)) Ok(Box::new(scorer))
} else { } else {

View File

@@ -61,11 +61,7 @@ pub(crate) struct RangeDocSet<T> {
const DEFAULT_FETCH_HORIZON: u32 = 128; const DEFAULT_FETCH_HORIZON: u32 = 128;
impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> { impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
pub(crate) fn new( pub(crate) fn new(value_range: RangeInclusive<T>, column: Column<T>) -> Self {
value_range: RangeInclusive<T>,
column: Column<T>,
seek_first_doc: DocId,
) -> Self {
if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() { if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() {
return Self { return Self {
value_range, value_range,
@@ -81,7 +77,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
value_range, value_range,
column, column,
loaded_docs: VecCursor::new(), loaded_docs: VecCursor::new(),
next_fetch_start: seek_first_doc, next_fetch_start: 0,
fetch_horizon: DEFAULT_FETCH_HORIZON, fetch_horizon: DEFAULT_FETCH_HORIZON,
last_seek_pos_opt: None, last_seek_pos_opt: None,
}; };
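
`RangeDocSet::new` (context above) starts with an early exit: if the requested value range lies entirely outside the column's `[min_value, max_value]`, no document can match and the docset is constructed empty. A short sketch of that check with a stand-in `Column` type:

```rust
// Short sketch of the early-exit check (stand-in `Column` with just min/max):
// a range entirely above the column's maximum or entirely below its minimum
// can be rejected before touching any document.

use std::ops::RangeInclusive;

struct Column {
    min_value: u64,
    max_value: u64,
}

fn range_can_match(value_range: &RangeInclusive<u64>, column: &Column) -> bool {
    !(*value_range.start() > column.max_value || *value_range.end() < column.min_value)
}

fn main() {
    let column = Column { min_value: 10, max_value: 90 };
    assert!(range_can_match(&(30..=70), &column));
    assert!(!range_can_match(&(100..=200), &column)); // entirely above max_value
    assert!(!range_can_match(&(0..=5), &column)); // entirely below min_value
}
```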

View File

@@ -212,12 +212,7 @@ impl InvertedIndexRangeWeight {
} }
impl Weight for InvertedIndexRangeWeight { impl Weight for InvertedIndexRangeWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc(); let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc); let mut doc_bitset = BitSet::with_max_value(max_doc);
@@ -234,12 +229,7 @@ impl Weight for InvertedIndexRangeWeight {
processed_count += 1; processed_count += 1;
let term_info = term_range.value(); let term_info = term_range.value();
let mut block_segment_postings = inverted_index let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo_with_seek( .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
term_info,
IndexRecordOption::Basic,
seek_doc,
)?
.0;
loop { loop {
let docs = block_segment_postings.docs(); let docs = block_segment_postings.docs();
if docs.is_empty() { if docs.is_empty() {
@@ -256,7 +246,7 @@ impl Weight for InvertedIndexRangeWeight {
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc { if scorer.seek(doc) != doc {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
@@ -696,7 +686,7 @@ mod tests {
.weight(EnableScoring::disabled_from_schema(&schema)) .weight(EnableScoring::disabled_from_schema(&schema))
.unwrap(); .unwrap();
let range_scorer = range_weight let range_scorer = range_weight
.scorer(&searcher.segment_readers()[0], 1.0f32, 0) .scorer(&searcher.segment_readers()[0], 1.0f32)
.unwrap(); .unwrap();
range_scorer range_scorer
}; };

View File

@@ -52,12 +52,7 @@ impl FastFieldRangeWeight {
} }
impl Weight for FastFieldRangeWeight { impl Weight for FastFieldRangeWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self,
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
// Check if both bounds are Bound::Unbounded // Check if both bounds are Bound::Unbounded
if self.bounds.is_unbounded() { if self.bounds.is_unbounded() {
return Ok(Box::new(AllScorer::new(reader.max_doc()))); return Ok(Box::new(AllScorer::new(reader.max_doc())));
@@ -114,21 +109,11 @@ impl Weight for FastFieldRangeWeight {
else { else {
return Ok(Box::new(EmptyScorer)); return Ok(Box::new(EmptyScorer));
}; };
search_on_u64_ff( search_on_u64_ff(column, boost, BoundsRange::new(lower_bound, upper_bound))
column, }
boost, Type::U64 | Type::I64 | Type::F64 => {
BoundsRange::new(lower_bound, upper_bound), search_on_json_numerical_field(reader, &field_name, typ, bounds, boost)
seek_doc,
)
} }
Type::U64 | Type::I64 | Type::F64 => search_on_json_numerical_field(
reader,
&field_name,
typ,
bounds,
boost,
seek_doc,
),
Type::Date => { Type::Date => {
let fast_field_reader = reader.fast_fields(); let fast_field_reader = reader.fast_fields();
let Some((column, _col_type)) = fast_field_reader let Some((column, _col_type)) = fast_field_reader
@@ -141,7 +126,6 @@ impl Weight for FastFieldRangeWeight {
column, column,
boost, boost,
BoundsRange::new(bounds.lower_bound, bounds.upper_bound), BoundsRange::new(bounds.lower_bound, bounds.upper_bound),
seek_doc,
) )
} }
Type::Bool | Type::Facet | Type::Bytes | Type::Json | Type::IpAddr => { Type::Bool | Type::Facet | Type::Bytes | Type::Json | Type::IpAddr => {
@@ -170,7 +154,7 @@ impl Weight for FastFieldRangeWeight {
ip_addr_column.min_value(), ip_addr_column.min_value(),
ip_addr_column.max_value(), ip_addr_column.max_value(),
); );
let docset = RangeDocSet::new(value_range, ip_addr_column, seek_doc); let docset = RangeDocSet::new(value_range, ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost))) Ok(Box::new(ConstScorer::new(docset, boost)))
} else if field_type.is_str() { } else if field_type.is_str() {
let Some(str_dict_column): Option<StrColumn> = reader.fast_fields().str(&field_name)? let Some(str_dict_column): Option<StrColumn> = reader.fast_fields().str(&field_name)?
@@ -189,12 +173,7 @@ impl Weight for FastFieldRangeWeight {
else { else {
return Ok(Box::new(EmptyScorer)); return Ok(Box::new(EmptyScorer));
}; };
search_on_u64_ff( search_on_u64_ff(column, boost, BoundsRange::new(lower_bound, upper_bound))
column,
boost,
BoundsRange::new(lower_bound, upper_bound),
seek_doc,
)
} else { } else {
assert!( assert!(
maps_to_u64_fastfield(field_type.value_type()), maps_to_u64_fastfield(field_type.value_type()),
@@ -236,13 +215,12 @@ impl Weight for FastFieldRangeWeight {
column, column,
boost, boost,
BoundsRange::new(bounds.lower_bound, bounds.upper_bound), BoundsRange::new(bounds.lower_bound, bounds.upper_bound),
seek_doc,
) )
} }
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc { if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!( return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match" "Document #({doc}) does not match"
@@ -252,10 +230,6 @@ impl Weight for FastFieldRangeWeight {
Ok(explanation) Ok(explanation)
} }
fn intersection_priority(&self) -> u32 {
30u32
}
} }
/// On numerical fields the column type may not match the user provided one. /// On numerical fields the column type may not match the user provided one.
@@ -267,7 +241,6 @@ fn search_on_json_numerical_field(
typ: Type, typ: Type,
bounds: BoundsRange<ValueBytes<Vec<u8>>>, bounds: BoundsRange<ValueBytes<Vec<u8>>>,
boost: Score, boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> { ) -> crate::Result<Box<dyn Scorer>> {
// Since we don't know which type was interpolated for the internal column we // Since we don't know which type was interpolated for the internal column we
// have to check for all numeric types (only one exists) // have to check for all numeric types (only one exists)
@@ -345,7 +318,6 @@ fn search_on_json_numerical_field(
column, column,
boost, boost,
BoundsRange::new(bounds.lower_bound, bounds.upper_bound), BoundsRange::new(bounds.lower_bound, bounds.upper_bound),
seek_doc,
) )
} }
@@ -424,7 +396,6 @@ fn search_on_u64_ff(
column: Column<u64>, column: Column<u64>,
boost: Score, boost: Score,
bounds: BoundsRange<u64>, bounds: BoundsRange<u64>,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> { ) -> crate::Result<Box<dyn Scorer>> {
let col_min_value = column.min_value(); let col_min_value = column.min_value();
let col_max_value = column.max_value(); let col_max_value = column.max_value();
@@ -455,8 +426,8 @@ fn search_on_u64_ff(
} }
} }
let doc_set = RangeDocSet::new(value_range, column, seek_doc); let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(doc_set, boost))) Ok(Box::new(ConstScorer::new(docset, boost)))
} }
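
`search_on_u64_ff` receives the bounds as a `BoundsRange` of `Bound<u64>` endpoints and ultimately hands `RangeDocSet::new` an inclusive range. As a hedged illustration (the helper below is hypothetical, not tantivy's code), converting `Bound` endpoints into an inclusive `u64` range might look like this, including the empty-range case:

```rust
// Hedged illustration (hypothetical helper): turning two `Bound<u64>`
// endpoints, as carried by something like
// `BoundsRange::new(lower_bound, upper_bound)`, into the inclusive range a
// fast-field docset can scan, with `None` signalling an empty range.

use std::ops::Bound;

fn to_inclusive(lower: Bound<u64>, upper: Bound<u64>) -> Option<(u64, u64)> {
    let start = match lower {
        Bound::Included(v) => v,
        Bound::Excluded(v) => v.checked_add(1)?,
        Bound::Unbounded => u64::MIN,
    };
    let end = match upper {
        Bound::Included(v) => v,
        Bound::Excluded(v) => v.checked_sub(1)?,
        Bound::Unbounded => u64::MAX,
    };
    (start <= end).then_some((start, end))
}

fn main() {
    assert_eq!(to_inclusive(Bound::Included(30), Bound::Included(70)), Some((30, 70)));
    assert_eq!(to_inclusive(Bound::Excluded(30), Bound::Unbounded), Some((31, u64::MAX)));
    // An empty interval maps to `None` (the caller would fall back to an empty scorer).
    assert_eq!(to_inclusive(Bound::Included(70), Bound::Excluded(70)), None);
}
```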
/// Returns true if the type maps to a u64 fast field /// Returns true if the type maps to a u64 fast field
@@ -533,7 +504,7 @@ mod tests {
DateOptions, Field, NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING, DateOptions, Field, NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING,
TEXT, TEXT,
}; };
use crate::{DocId, Index, IndexWriter, TantivyDocument, Term, TERMINATED}; use crate::{Index, IndexWriter, TantivyDocument, Term, TERMINATED};
#[test] #[test]
fn test_text_field_ff_range_query() -> crate::Result<()> { fn test_text_field_ff_range_query() -> crate::Result<()> {
@@ -1171,52 +1142,11 @@ mod tests {
Bound::Included(Term::from_field_u64(field, 50_002)), Bound::Included(Term::from_field_u64(field, 50_002)),
)); ));
let scorer = range_query let scorer = range_query
.scorer(searcher.segment_reader(0), 1.0f32, 0) .scorer(searcher.segment_reader(0), 1.0f32)
.unwrap(); .unwrap();
assert_eq!(scorer.doc(), TERMINATED); assert_eq!(scorer.doc(), TERMINATED);
} }
#[test]
fn test_fastfield_range_weight_seek_doc() {
let mut schema_builder = SchemaBuilder::new();
let field = schema_builder.add_u64_field("value", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
// Create 11 documents with values
// 0, 10, 20, ..., 90
// and then 50 again.
for i in 0..10 {
writer.add_document(doc!(field => (i * 10) as u64)).unwrap();
}
writer.add_document(doc!(field => 50u64)).unwrap();
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let range_weight = FastFieldRangeWeight::new(BoundsRange::new(
Bound::Included(Term::from_field_u64(field, 30)),
Bound::Included(Term::from_field_u64(field, 70)),
));
let doc_when_seeking_from = |seek_from: DocId| {
let doc_set = range_weight
.scorer(segment_reader, 1.0f32, seek_from)
.unwrap();
crate::docset::docset_to_doc_vec(doc_set)
};
assert_eq!(doc_when_seeking_from(0), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(1), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(3), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(7), vec![7, 10]);
assert_eq!(doc_when_seeking_from(8), vec![10]);
assert_eq!(doc_when_seeking_from(10), vec![10]);
assert_eq!(doc_when_seeking_from(11), Vec::<DocId>::new());
}
#[test] #[test]
fn range_regression3_test() { fn range_regression3_test() {
let ops = vec![doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)]; let ops = vec![doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];

View File

@@ -81,6 +81,7 @@ where
TOptScorer: Scorer, TOptScorer: Scorer,
TScoreCombiner: ScoreCombiner, TScoreCombiner: ScoreCombiner,
{ {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
if let Some(score) = self.score_cache { if let Some(score) = self.score_cache {
return score; return score;

View File

@@ -29,6 +29,7 @@ impl ScoreCombiner for DoNothingCombiner {
fn clear(&mut self) {} fn clear(&mut self) {}
#[inline]
fn score(&self) -> Score { fn score(&self) -> Score {
1.0 1.0
} }
@@ -49,6 +50,7 @@ impl ScoreCombiner for SumCombiner {
self.score = 0.0; self.score = 0.0;
} }
#[inline]
fn score(&self) -> Score { fn score(&self) -> Score {
self.score self.score
} }
@@ -86,6 +88,7 @@ impl ScoreCombiner for DisjunctionMaxCombiner {
self.sum = 0.0; self.sum = 0.0;
} }
#[inline]
fn score(&self) -> Score { fn score(&self) -> Score {
self.max + (self.sum - self.max) * self.tie_breaker self.max + (self.sum - self.max) * self.tie_breaker
} }

View File

@@ -18,6 +18,7 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
impl_downcast!(Scorer); impl_downcast!(Scorer);
impl Scorer for Box<dyn Scorer> { impl Scorer for Box<dyn Scorer> {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.deref_mut().score() self.deref_mut().score()
} }

View File

@@ -37,7 +37,7 @@ mod tests {
); );
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let segment_reader = searcher.segment_reader(0); let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0, 0)?; let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
assert_eq!(term_scorer.doc(), 0); assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.score(), 0.28768212); assert_nearly_equals!(term_scorer.score(), 0.28768212);
Ok(()) Ok(())
@@ -65,7 +65,7 @@ mod tests {
); );
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let segment_reader = searcher.segment_reader(0); let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0, 0)?; let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 { for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 {
assert_eq!(term_scorer.doc(), i); assert_eq!(term_scorer.doc(), i);
if i == COMPRESSION_BLOCK_SIZE as u32 - 1u32 { if i == COMPRESSION_BLOCK_SIZE as u32 - 1u32 {
@@ -162,7 +162,7 @@ mod tests {
let searcher = index.reader()?.searcher(); let searcher = index.reader()?.searcher();
let term_weight = let term_weight =
term_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?; term_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let mut term_scorer = term_weight.scorer(searcher.segment_reader(0u32), 1.0, 0)?; let mut term_scorer = term_weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert_eq!(term_scorer.doc(), 0u32); assert_eq!(term_scorer.doc(), 0u32);
term_scorer.seek(1u32); term_scorer.seek(1u32);
assert_eq!(term_scorer.doc(), 1u32); assert_eq!(term_scorer.doc(), 1u32);
@@ -470,7 +470,7 @@ mod tests {
.weight(EnableScoring::disabled_from_schema(&schema)) .weight(EnableScoring::disabled_from_schema(&schema))
.unwrap(); .unwrap();
term_weight term_weight
.scorer(searcher.segment_reader(0u32), 1.0f32, 0) .scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap() .unwrap()
}; };
// Should be an allscorer // Should be an allscorer
@@ -484,53 +484,6 @@ mod tests {
assert!(empty_scorer.is::<EmptyScorer>()); assert!(empty_scorer.is::<EmptyScorer>());
} }
#[test]
fn test_term_weight_seek_doc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Create 11 documents where docs 3, 4, 5, 6, 7, and 10 contain "target"
// (similar pattern to test_fastfield_range_weight_seek_doc)
for i in 0..11 {
if i == 3 || i == 4 || i == 5 || i == 6 || i == 7 || i == 10 {
index_writer.add_document(doc!(text_field => "target"))?;
} else {
index_writer.add_document(doc!(text_field => "other"))?;
}
}
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let term_query = TermQuery::new(
Term::from_field_text(text_field, "target"),
IndexRecordOption::Basic,
);
let term_weight =
term_query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
let doc_when_seeking_from = |seek_from: crate::DocId| {
let scorer = term_weight
.scorer(segment_reader, 1.0f32, seek_from)
.unwrap();
crate::docset::docset_to_doc_vec(scorer)
};
assert_eq!(doc_when_seeking_from(0), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(1), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(3), vec![3, 4, 5, 6, 7, 10]);
assert_eq!(doc_when_seeking_from(7), vec![7, 10]);
assert_eq!(doc_when_seeking_from(8), vec![10]);
assert_eq!(doc_when_seeking_from(10), vec![10]);
assert_eq!(doc_when_seeking_from(11), Vec::<crate::DocId>::new());
Ok(())
}
#[test] #[test]
fn test_term_weight_all_query_optimization_disable_when_scoring_enabled() { fn test_term_weight_all_query_optimization_disable_when_scoring_enabled() {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
@@ -556,7 +509,7 @@ mod tests {
.weight(EnableScoring::enabled_from_searcher(&searcher)) .weight(EnableScoring::enabled_from_searcher(&searcher))
.unwrap(); .unwrap();
term_weight term_weight
.scorer(searcher.segment_reader(0u32), 1.0f32, 0) .scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap() .unwrap()
}; };
// Should be an allscorer // Should be an allscorer


@@ -119,6 +119,7 @@ impl DocSet for TermScorer {
} }
impl Scorer for TermScorer { impl Scorer for TermScorer {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
let fieldnorm_id = self.fieldnorm_id(); let fieldnorm_id = self.fieldnorm_id();
let term_freq = self.term_freq(); let term_freq = self.term_freq();


@@ -34,19 +34,12 @@ impl TermOrEmptyOrAllScorer {
} }
impl Weight for TermWeight { impl Weight for TermWeight {
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
&self, Ok(self.specialized_scorer(reader, boost)?.into_boxed_scorer())
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>> {
Ok(self
.specialized_scorer(reader, boost, seek_doc)?
.into_boxed_scorer())
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
match self.specialized_scorer(reader, 1.0, doc)? { match self.specialized_scorer(reader, 1.0)? {
TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => {
if term_scorer.doc() > doc || term_scorer.seek(doc) != doc { if term_scorer.doc() > doc || term_scorer.seek(doc) != doc {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
@@ -62,7 +55,7 @@ impl Weight for TermWeight {
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> { fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
if let Some(alive_bitset) = reader.alive_bitset() { if let Some(alive_bitset) = reader.alive_bitset() {
Ok(self.scorer(reader, 1.0, 0)?.count(alive_bitset)) Ok(self.scorer(reader, 1.0)?.count(alive_bitset))
} else { } else {
let field = self.term.field(); let field = self.term.field();
let inv_index = reader.inverted_index(field)?; let inv_index = reader.inverted_index(field)?;
@@ -78,7 +71,7 @@ impl Weight for TermWeight {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score), callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> { ) -> crate::Result<()> {
match self.specialized_scorer(reader, 1.0, 0u32)? { match self.specialized_scorer(reader, 1.0)? {
TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => {
for_each_scorer(&mut *term_scorer, callback); for_each_scorer(&mut *term_scorer, callback);
} }
@@ -97,7 +90,7 @@ impl Weight for TermWeight {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(&[DocId]), callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> { ) -> crate::Result<()> {
match self.specialized_scorer(reader, 1.0, 0u32)? { match self.specialized_scorer(reader, 1.0)? {
TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => {
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
for_each_docset_buffered(&mut term_scorer, &mut buffer, callback); for_each_docset_buffered(&mut term_scorer, &mut buffer, callback);
@@ -128,7 +121,7 @@ impl Weight for TermWeight {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score, callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> { ) -> crate::Result<()> {
let specialized_scorer = self.specialized_scorer(reader, 1.0, 0u32)?; let specialized_scorer = self.specialized_scorer(reader, 1.0)?;
match specialized_scorer { match specialized_scorer {
TermOrEmptyOrAllScorer::TermScorer(term_scorer) => { TermOrEmptyOrAllScorer::TermScorer(term_scorer) => {
crate::query::boolean_query::block_wand_single_scorer( crate::query::boolean_query::block_wand_single_scorer(
@@ -146,12 +139,6 @@ impl Weight for TermWeight {
} }
Ok(()) Ok(())
} }
/// Returns a priority number used to sort weights when running an
/// intersection.
fn intersection_priority(&self) -> u32 {
10u32
}
} }
impl TermWeight { impl TermWeight {
@@ -182,7 +169,7 @@ impl TermWeight {
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
) -> crate::Result<Option<TermScorer>> { ) -> crate::Result<Option<TermScorer>> {
let scorer = self.specialized_scorer(reader, boost, 0u32)?; let scorer = self.specialized_scorer(reader, boost)?;
Ok(match scorer { Ok(match scorer {
TermOrEmptyOrAllScorer::TermScorer(scorer) => Some(*scorer), TermOrEmptyOrAllScorer::TermScorer(scorer) => Some(*scorer),
_ => None, _ => None,
@@ -193,7 +180,6 @@ impl TermWeight {
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: Score, boost: Score,
seek_doc: DocId,
) -> crate::Result<TermOrEmptyOrAllScorer> { ) -> crate::Result<TermOrEmptyOrAllScorer> {
let field = self.term.field(); let field = self.term.field();
let inverted_index = reader.inverted_index(field)?; let inverted_index = reader.inverted_index(field)?;
@@ -210,11 +196,8 @@ impl TermWeight {
))); )));
} }
let segment_postings: SegmentPostings = inverted_index.read_postings_from_terminfo( let segment_postings: SegmentPostings =
&term_info, inverted_index.read_postings_from_terminfo(&term_info, self.index_record_option)?;
self.index_record_option,
seek_doc,
)?;
let fieldnorm_reader = self.fieldnorm_reader(reader)?; let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let similarity_weight = self.similarity_weight.boost_by(boost); let similarity_weight = self.similarity_weight.boost_by(boost);


@@ -128,6 +128,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
} }
} }
#[inline]
fn advance_buffered(&mut self) -> bool { fn advance_buffered(&mut self) -> bool {
while self.bucket_idx < HORIZON_NUM_TINYBITSETS { while self.bucket_idx < HORIZON_NUM_TINYBITSETS {
if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() { if let Some(val) = self.bitsets[self.bucket_idx].pop_lowest() {
@@ -156,6 +157,7 @@ where
TScorer: Scorer, TScorer: Scorer,
TScoreCombiner: ScoreCombiner, TScoreCombiner: ScoreCombiner,
{ {
#[inline]
fn advance(&mut self) -> DocId { fn advance(&mut self) -> DocId {
if self.advance_buffered() { if self.advance_buffered() {
return self.doc; return self.doc;
@@ -245,6 +247,7 @@ where
} }
} }
#[inline]
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
self.doc self.doc
} }
@@ -286,6 +289,7 @@ where
TScoreCombiner: ScoreCombiner, TScoreCombiner: ScoreCombiner,
TScorer: Scorer, TScorer: Scorer,
{ {
#[inline]
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.score self.score
} }


@@ -68,28 +68,15 @@ pub trait Weight: Send + Sync + 'static {
/// ///
/// `boost` is a multiplier to apply to the score. /// `boost` is a multiplier to apply to the score.
/// ///
/// As an optimization, the scorer can be positioned on any document below `seek_doc`
/// matching the request.
/// If there is no such document, it should match the first document matching the request;
/// (or TERMINATED if no documents match).
///
/// Entirely ignoring that parameter and positioning the Scorer on the first document
/// is always correct.
///
/// See [`Query`](crate::query::Query). /// See [`Query`](crate::query::Query).
fn scorer( fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>>;
&self,
reader: &SegmentReader,
boost: Score,
seek_doc: DocId,
) -> crate::Result<Box<dyn Scorer>>;
/// Returns an [`Explanation`] for the given document. /// Returns an [`Explanation`] for the given document.
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation>; fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation>;
/// Returns the number of documents within the given [`SegmentReader`]. /// Returns the number of documents within the given [`SegmentReader`].
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> { fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
if let Some(alive_bitset) = reader.alive_bitset() { if let Some(alive_bitset) = reader.alive_bitset() {
Ok(scorer.count(alive_bitset)) Ok(scorer.count(alive_bitset))
} else { } else {
@@ -104,7 +91,7 @@ pub trait Weight: Send + Sync + 'static {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score), callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> { ) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
for_each_scorer(scorer.as_mut(), callback); for_each_scorer(scorer.as_mut(), callback);
Ok(()) Ok(())
} }
@@ -116,7 +103,7 @@ pub trait Weight: Send + Sync + 'static {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(&[DocId]), callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> { ) -> crate::Result<()> {
let mut docset = self.scorer(reader, 1.0, 0)?; let mut docset = self.scorer(reader, 1.0)?;
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
for_each_docset_buffered(&mut docset, &mut buffer, callback); for_each_docset_buffered(&mut docset, &mut buffer, callback);
@@ -139,18 +126,8 @@ pub trait Weight: Send + Sync + 'static {
reader: &SegmentReader, reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score, callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> { ) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0, 0)?; let mut scorer = self.scorer(reader, 1.0)?;
for_each_pruning_scorer(scorer.as_mut(), threshold, callback); for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
Ok(()) Ok(())
} }
/// Returns a priority number used to sort weights when running an
/// intersection.
///
/// Tweaking this value only impacts performance.
/// A higher priority means that the `.scorer()` will be more likely to be evaluated
/// after the sibling weights, and be passed a higher `seek_doc` value as a result.
fn intersection_priority(&self) -> u32 {
20u32
}
} }
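For context, a minimal sketch of driving the two-argument `scorer` signature end to end. The schema, field name, memory budget and loop body are illustrative; only the `Weight::scorer(reader, boost)` call shape comes from the code above.

use tantivy::query::{EnableScoring, Query, TermQuery, Weight};
use tantivy::schema::{IndexRecordOption, Schema, TEXT};
use tantivy::{doc, DocSet, Index, IndexWriter, Term, TERMINATED};

fn scorer_usage_sketch() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let text_field = schema_builder.add_text_field("text", TEXT);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer(50_000_000)?;
    writer.add_document(doc!(text_field => "target"))?;
    writer.commit()?;
    let searcher = index.reader()?.searcher();

    let query = TermQuery::new(
        Term::from_field_text(text_field, "target"),
        IndexRecordOption::Basic,
    );
    let weight = query.weight(EnableScoring::disabled_from_schema(searcher.schema()))?;
    // No seek_doc hint anymore: the scorer starts on the first matching document.
    let mut scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
    while scorer.doc() != TERMINATED {
        // consume scorer.doc() (and scorer.score() when scoring is enabled) here
        scorer.advance();
    }
    Ok(())
}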


@@ -58,6 +58,31 @@ impl AsRef<OwnedValue> for OwnedValue {
} }
} }
impl OwnedValue {
/// Returns a u8 discriminant value for the `OwnedValue` variant.
///
/// This can be used to sort `OwnedValue` instances by their type.
pub fn discriminant_value(&self) -> u8 {
match self {
OwnedValue::Null => 0,
OwnedValue::Str(_) => 1,
OwnedValue::PreTokStr(_) => 2,
// It is key to make sure U64, I64, F64 are grouped together here, otherwise we
// might be breaking transitivity.
OwnedValue::U64(_) => 3,
OwnedValue::I64(_) => 4,
OwnedValue::F64(_) => 5,
OwnedValue::Bool(_) => 6,
OwnedValue::Date(_) => 7,
OwnedValue::Facet(_) => 8,
OwnedValue::Bytes(_) => 9,
OwnedValue::Array(_) => 10,
OwnedValue::Object(_) => 11,
OwnedValue::IpAddr(_) => 12,
}
}
}
impl<'a> Value<'a> for &'a OwnedValue { impl<'a> Value<'a> for &'a OwnedValue {
type ArrayIter = std::slice::Iter<'a, OwnedValue>; type ArrayIter = std::slice::Iter<'a, OwnedValue>;
type ObjectIter = ObjectMapIter<'a>; type ObjectIter = ObjectMapIter<'a>;
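A small sketch of how the new discriminant can serve as a coarse, type-level sort key (the values are illustrative; as the comment above notes, a full comparator must also compare values within each group and treat U64/I64/F64 as a single numeric group to preserve transitivity):

use tantivy::schema::OwnedValue;

fn discriminant_sort_sketch() {
    let mut values = vec![
        OwnedValue::U64(42),
        OwnedValue::Str("b".to_string()),
        OwnedValue::Null,
    ];
    // Order by variant only; ties within a variant keep their original order.
    values.sort_by_key(|v| v.discriminant_value());
    assert!(matches!(values[0], OwnedValue::Null));
    assert!(matches!(values[1], OwnedValue::Str(_)));
    assert!(matches!(values[2], OwnedValue::U64(_)));
}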


@@ -98,6 +98,10 @@
//! make it possible to access the value given the doc id rapidly. This is useful if the value //! make it possible to access the value given the doc id rapidly. This is useful if the value
//! of the field is required during scoring or collection for instance. //! of the field is required during scoring or collection for instance.
//! //!
//! Some queries may leverage fast fields when run on a field that is not indexed. This can be
//! handy if that kind of request is infrequent; note, however, that searching on a fast field is
//! generally much slower than searching in an index.
//!
//! ``` //! ```
//! use tantivy::schema::*; //! use tantivy::schema::*;
//! let mut schema_builder = Schema::builder(); //! let mut schema_builder = Schema::builder();
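As a sketch of the behaviour documented above, a field declared FAST but not INDEXED can still serve certain queries from its columnar data. Which queries support this is version dependent; the field name is illustrative and the `RangeQuery::new(Bound<Term>, Bound<Term>)` constructor is assumed from recent tantivy APIs.

use std::ops::Bound;

use tantivy::collector::Count;
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST};
use tantivy::{doc, Index, IndexWriter, Term};

fn fast_only_range_sketch() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    // FAST only: no inverted index is built for this field.
    let price = schema_builder.add_u64_field("price", FAST);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer(50_000_000)?;
    writer.add_document(doc!(price => 12u64))?;
    writer.add_document(doc!(price => 99u64))?;
    writer.commit()?;
    let searcher = index.reader()?.searcher();
    // Assumed constructor: the range is expressed as a pair of term bounds.
    let query = RangeQuery::new(
        Bound::Included(Term::from_field_u64(price, 10)),
        Bound::Excluded(Term::from_field_u64(price, 20)),
    );
    // Expected to be answered from the fast field, counting the single match (12).
    let hits = searcher.search(&query, &Count)?;
    assert_eq!(hits, 1);
    Ok(())
}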


@@ -483,7 +483,7 @@ mod tests {
use super::{collapse_overlapped_ranges, search_fragments, select_best_fragment_combination}; use super::{collapse_overlapped_ranges, search_fragments, select_best_fragment_combination};
use crate::query::QueryParser; use crate::query::QueryParser;
use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, TextOptions, TEXT}; use crate::schema::{Schema, TEXT};
use crate::snippet::SnippetGenerator; use crate::snippet::SnippetGenerator;
use crate::tokenizer::{NgramTokenizer, SimpleTokenizer}; use crate::tokenizer::{NgramTokenizer, SimpleTokenizer};
use crate::Index; use crate::Index;
@@ -727,8 +727,10 @@ Survey in 2016, 2017, and 2018."#;
Ok(()) Ok(())
} }
#[cfg(feature = "stemmer")]
#[test] #[test]
fn test_snippet_generator() -> crate::Result<()> { fn test_snippet_generator() -> crate::Result<()> {
use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions};
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let text_options = TextOptions::default().set_indexing_options( let text_options = TextOptions::default().set_indexing_options(
TextFieldIndexing::default() TextFieldIndexing::default()


@@ -102,6 +102,7 @@ pub(crate) mod tests {
} }
const NUM_DOCS: usize = 1_000; const NUM_DOCS: usize = 1_000;
#[test] #[test]
fn test_doc_store_iter_with_delete_bug_1077() -> crate::Result<()> { fn test_doc_store_iter_with_delete_bug_1077() -> crate::Result<()> {
// this will cover deletion of the first element in a checkpoint // this will cover deletion of the first element in a checkpoint
@@ -113,7 +114,7 @@ pub(crate) mod tests {
let directory = RamDirectory::create(); let directory = RamDirectory::create();
let store_wrt = directory.open_write(path)?; let store_wrt = directory.open_write(path)?;
let schema = let schema =
write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::Lz4, BLOCK_SIZE, true); write_lorem_ipsum_store(store_wrt, NUM_DOCS, Compressor::default(), BLOCK_SIZE, true);
let field_title = schema.get_field("title").unwrap(); let field_title = schema.get_field("title").unwrap();
let store_file = directory.open_read(path)?; let store_file = directory.open_read(path)?;
let store = StoreReader::open(store_file, 10)?; let store = StoreReader::open(store_file, 10)?;


@@ -465,7 +465,7 @@ mod tests {
let directory = RamDirectory::create(); let directory = RamDirectory::create();
let path = Path::new("store"); let path = Path::new("store");
let writer = directory.open_write(path)?; let writer = directory.open_write(path)?;
let schema = write_lorem_ipsum_store(writer, 500, Compressor::default(), BLOCK_SIZE, true); let schema = write_lorem_ipsum_store(writer, 500, Compressor::None, BLOCK_SIZE, true);
let title = schema.get_field("title").unwrap(); let title = schema.get_field("title").unwrap();
let store_file = directory.open_read(path)?; let store_file = directory.open_read(path)?;
let store = StoreReader::open(store_file, DOCSTORE_CACHE_CAPACITY)?; let store = StoreReader::open(store_file, DOCSTORE_CACHE_CAPACITY)?;
@@ -499,7 +499,7 @@ mod tests {
assert_eq!(store.cache_stats().cache_hits, 1); assert_eq!(store.cache_stats().cache_hits, 1);
assert_eq!(store.cache_stats().cache_misses, 2); assert_eq!(store.cache_stats().cache_misses, 2);
assert_eq!(store.cache.peek_lru(), Some(11207)); assert_eq!(store.cache.peek_lru(), Some(232206));
Ok(()) Ok(())
} }


@@ -132,13 +132,14 @@ mod regex_tokenizer;
mod remove_long; mod remove_long;
mod simple_tokenizer; mod simple_tokenizer;
mod split_compound_words; mod split_compound_words;
mod stemmer;
mod stop_word_filter; mod stop_word_filter;
mod tokenized_string; mod tokenized_string;
mod tokenizer; mod tokenizer;
mod tokenizer_manager; mod tokenizer_manager;
mod whitespace_tokenizer; mod whitespace_tokenizer;
#[cfg(feature = "stemmer")]
mod stemmer;
pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer}; pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer};
pub use self::alphanum_only::AlphaNumOnlyFilter; pub use self::alphanum_only::AlphaNumOnlyFilter;
@@ -151,6 +152,7 @@ pub use self::regex_tokenizer::RegexTokenizer;
pub use self::remove_long::RemoveLongFilter; pub use self::remove_long::RemoveLongFilter;
pub use self::simple_tokenizer::{SimpleTokenStream, SimpleTokenizer}; pub use self::simple_tokenizer::{SimpleTokenStream, SimpleTokenizer};
pub use self::split_compound_words::SplitCompoundWords; pub use self::split_compound_words::SplitCompoundWords;
#[cfg(feature = "stemmer")]
pub use self::stemmer::{Language, Stemmer}; pub use self::stemmer::{Language, Stemmer};
pub use self::stop_word_filter::StopWordFilter; pub use self::stop_word_filter::StopWordFilter;
pub use self::tokenized_string::{PreTokenizedStream, PreTokenizedString}; pub use self::tokenized_string::{PreTokenizedStream, PreTokenizedString};
@@ -167,10 +169,7 @@ pub const MAX_TOKEN_LEN: usize = u16::MAX as usize - 5;
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::{ use super::{Token, TokenizerManager};
Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token, TokenizerManager,
};
use crate::tokenizer::TextAnalyzer;
/// This is a function that can be used in tests and doc tests /// This is a function that can be used in tests and doc tests
/// to assert a token's correctness. /// to assert a token's correctness.
@@ -205,59 +204,15 @@ pub(crate) mod tests {
} }
#[test] #[test]
fn test_en_tokenizer() { fn test_tokenizer_does_not_exist() {
let tokenizer_manager = TokenizerManager::default(); let tokenizer_manager = TokenizerManager::default();
assert!(tokenizer_manager.get("en_doesnotexist").is_none()); assert!(tokenizer_manager.get("en_doesnotexist").is_none());
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap();
let mut tokens: Vec<Token> = vec![];
{
let mut add_token = |token: &Token| {
tokens.push(token.clone());
};
en_tokenizer
.token_stream("Hello, happy tax payer!")
.process(&mut add_token);
}
assert_eq!(tokens.len(), 4);
assert_token(&tokens[0], 0, "hello", 0, 5);
assert_token(&tokens[1], 1, "happi", 7, 12);
assert_token(&tokens[2], 2, "tax", 13, 16);
assert_token(&tokens[3], 3, "payer", 17, 22);
}
#[test]
fn test_non_en_tokenizer() {
let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register(
"el_stem",
TextAnalyzer::builder(SimpleTokenizer::default())
.filter(RemoveLongFilter::limit(40))
.filter(LowerCaser)
.filter(Stemmer::new(Language::Greek))
.build(),
);
let mut en_tokenizer = tokenizer_manager.get("el_stem").unwrap();
let mut tokens: Vec<Token> = vec![];
{
let mut add_token = |token: &Token| {
tokens.push(token.clone());
};
en_tokenizer
.token_stream("Καλημέρα, χαρούμενε φορολογούμενε!")
.process(&mut add_token);
}
assert_eq!(tokens.len(), 3);
assert_token(&tokens[0], 0, "καλημερ", 0, 16);
assert_token(&tokens[1], 1, "χαρουμεν", 18, 36);
assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63);
} }
#[test] #[test]
fn test_tokenizer_empty() { fn test_tokenizer_empty() {
let tokenizer_manager = TokenizerManager::default(); let tokenizer_manager = TokenizerManager::default();
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap(); let mut en_tokenizer = tokenizer_manager.get("default").unwrap();
{ {
let mut tokens: Vec<Token> = vec![]; let mut tokens: Vec<Token> = vec![];
{ {


@@ -142,3 +142,60 @@ impl<T: TokenStream> TokenStream for StemmerTokenStream<T> {
self.tail.token_mut() self.tail.token_mut()
} }
} }
#[cfg(test)]
mod tests {
use tokenizer_api::Token;
use super::*;
use crate::tokenizer::tests::assert_token;
use crate::tokenizer::{LowerCaser, SimpleTokenizer, TextAnalyzer, TokenizerManager};
#[test]
fn test_en_stem() {
let tokenizer_manager = TokenizerManager::default();
let mut en_tokenizer = tokenizer_manager.get("en_stem").unwrap();
let mut tokens: Vec<Token> = vec![];
{
let mut add_token = |token: &Token| {
tokens.push(token.clone());
};
en_tokenizer
.token_stream("Dogs are the bests!")
.process(&mut add_token);
}
assert_eq!(tokens.len(), 4);
assert_token(&tokens[0], 0, "dog", 0, 4);
assert_token(&tokens[1], 1, "are", 5, 8);
assert_token(&tokens[2], 2, "the", 9, 12);
assert_token(&tokens[3], 3, "best", 13, 18);
}
#[test]
fn test_non_en_stem() {
let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register(
"el_stem",
TextAnalyzer::builder(SimpleTokenizer::default())
.filter(LowerCaser)
.filter(Stemmer::new(Language::Greek))
.build(),
);
let mut el_tokenizer = tokenizer_manager.get("el_stem").unwrap();
let mut tokens: Vec<Token> = vec![];
{
let mut add_token = |token: &Token| {
tokens.push(token.clone());
};
el_tokenizer
.token_stream("Καλημέρα, χαρούμενε φορολογούμενε!")
.process(&mut add_token);
}
assert_eq!(tokens.len(), 3);
assert_token(&tokens[0], 0, "καλημερ", 0, 16);
assert_token(&tokens[1], 1, "χαρουμεν", 18, 36);
assert_token(&tokens[2], 2, "φορολογουμεν", 37, 63);
}
}


@@ -1,10 +1,9 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use crate::tokenizer::stemmer::Language;
use crate::tokenizer::tokenizer::TextAnalyzer; use crate::tokenizer::tokenizer::TextAnalyzer;
use crate::tokenizer::{ use crate::tokenizer::{
LowerCaser, RawTokenizer, RemoveLongFilter, SimpleTokenizer, Stemmer, WhitespaceTokenizer, LowerCaser, RawTokenizer, RemoveLongFilter, SimpleTokenizer, WhitespaceTokenizer,
}; };
/// The tokenizer manager serves as a store for /// The tokenizer manager serves as a store for
@@ -64,14 +63,18 @@ impl Default for TokenizerManager {
.filter(LowerCaser) .filter(LowerCaser)
.build(), .build(),
); );
manager.register( #[cfg(feature = "stemmer")]
"en_stem", {
TextAnalyzer::builder(SimpleTokenizer::default()) use crate::tokenizer::stemmer::{Language, Stemmer};
.filter(RemoveLongFilter::limit(40)) manager.register(
.filter(LowerCaser) "en_stem",
.filter(Stemmer::new(Language::English)) TextAnalyzer::builder(SimpleTokenizer::default())
.build(), .filter(RemoveLongFilter::limit(40))
); .filter(LowerCaser) // The stemmer does not lowercase
.filter(Stemmer::new(Language::English))
.build(),
);
}
manager.register("whitespace", WhitespaceTokenizer::default()); manager.register("whitespace", WhitespaceTokenizer::default());
manager manager
} }
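With the `stemmer` feature enabled, downstream code can still assemble the same pipeline under a name of its own; a sketch mirroring the registration above (the name `my_en_stem` is illustrative):

use tantivy::tokenizer::{
    Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, TextAnalyzer,
    TokenizerManager,
};

fn register_stemmed_analyzer(manager: &TokenizerManager) {
    manager.register(
        "my_en_stem",
        TextAnalyzer::builder(SimpleTokenizer::default())
            .filter(RemoveLongFilter::limit(40))
            .filter(LowerCaser) // the Stemmer does not lowercase on its own
            .filter(Stemmer::new(Language::English))
            .build(),
    );
}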